From c8eb2e3376becfac45f508d03b54e023d56015e6 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 29 May 2024 20:34:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B6=88=E6=81=AF=E6=88=AA=E6=96=AD?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migrations/m009_msg_truncator_cfg.py | 24 ++++++++ pkg/core/stages/migrate.py | 2 +- pkg/pipeline/msgtrun/__init__.py | 0 pkg/pipeline/msgtrun/msgtrun.py | 35 ++++++++++++ pkg/pipeline/msgtrun/truncator.py | 56 +++++++++++++++++++ pkg/pipeline/msgtrun/truncators/__init__.py | 0 pkg/pipeline/msgtrun/truncators/round.py | 32 +++++++++++ pkg/pipeline/stagemgr.py | 2 + templates/pipeline.json | 6 ++ 9 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 pkg/config/migrations/m009_msg_truncator_cfg.py create mode 100644 pkg/pipeline/msgtrun/__init__.py create mode 100644 pkg/pipeline/msgtrun/msgtrun.py create mode 100644 pkg/pipeline/msgtrun/truncator.py create mode 100644 pkg/pipeline/msgtrun/truncators/__init__.py create mode 100644 pkg/pipeline/msgtrun/truncators/round.py diff --git a/pkg/config/migrations/m009_msg_truncator_cfg.py b/pkg/config/migrations/m009_msg_truncator_cfg.py new file mode 100644 index 00000000..369b60eb --- /dev/null +++ b/pkg/config/migrations/m009_msg_truncator_cfg.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("msg-truncator-cfg-migration", 9) +class MsgTruncatorConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'msg-truncate' not in self.ap.pipeline_cfg.data + + async def run(self): + """执行迁移""" + + self.ap.pipeline_cfg.data['msg-truncate'] = { + 'method': 'round', + 'round': { + 'max-round': 10 + } + } + + await self.ap.pipeline_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 68a87d4e..2ad1e974 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -5,7 +5,7 @@ import importlib from .. import stage, app from ...config import migration from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate +from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg @stage.stage_class("MigrationStage") diff --git a/pkg/pipeline/msgtrun/__init__.py b/pkg/pipeline/msgtrun/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py new file mode 100644 index 00000000..e56c551f --- /dev/null +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from . import truncator +from .truncators import round + + +@stage.stage_class("ConversationMessageTruncator") +class ConversationMessageTruncator(stage.PipelineStage): + """会话消息截断器 + + 用于截断会话消息链,以适应平台消息长度限制。 + """ + trun: truncator.Truncator + + async def initialize(self): + use_method = self.ap.pipeline_cfg.data['msg-truncate']['method'] + + for trun in truncator.preregistered_truncators: + if trun.name == use_method: + self.trun = trun(self.ap) + break + else: + raise ValueError(f"未知的截断器: {use_method}") + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + """处理 + """ + query = await self.trun.truncate(query) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/msgtrun/truncator.py b/pkg/pipeline/msgtrun/truncator.py new file mode 100644 index 00000000..4afaf9fb --- /dev/null +++ b/pkg/pipeline/msgtrun/truncator.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import typing +import abc + +from ...core import entities as core_entities, app + + +preregistered_truncators: list[typing.Type[Truncator]] = [] + + +def truncator_class( + name: str +) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: + """截断器类装饰器 + + Args: + name (str): 截断器名称 + + Returns: + typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器 + """ + def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]: + assert issubclass(cls, Truncator) + + cls.name = name + + preregistered_truncators.append(cls) + + return cls + + return decorator + + +class Truncator(abc.ABC): + """消息截断器基类 + """ + + name: str + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def truncate(self, query: core_entities.Query) -> core_entities.Query: + """截断 + + 一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。 + 请勿操作其他字段。 + """ + pass diff --git a/pkg/pipeline/msgtrun/truncators/__init__.py b/pkg/pipeline/msgtrun/truncators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py new file mode 100644 index 00000000..646f2856 --- /dev/null +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from .. import truncator +from ....core import entities as core_entities + + +@truncator.truncator_class("round") +class RoundTruncator(truncator.Truncator): + """前文回合数阶段器 + """ + + async def truncate(self, query: core_entities.Query) -> core_entities.Query: + """截断 + """ + max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round'] + + temp_messages = [] + + current_round = 0 + + # 从后往前遍历 + for msg in query.messages[::-1]: + if current_round < max_round: + temp_messages.append(msg) + if msg.role == 'user': + current_round += 1 + else: + break + + query.messages = temp_messages[::-1] + + return query diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index 46957aad..fe3b4256 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -13,6 +13,7 @@ from .respback import respback from .wrapper import wrapper from .preproc import preproc from .ratelimit import ratelimit +from .msgtrun import msgtrun # 请求处理阶段顺序 @@ -21,6 +22,7 @@ stage_order = [ "BanSessionCheckStage", # 封禁会话检查 "PreContentFilterStage", # 内容过滤前置阶段 "PreProcessor", # 预处理器 + "ConversationMessageTruncator", # 会话消息截断器 "RequireRateLimitOccupancy", # 请求速率限制占用 "MessageProcessor", # 处理器 "ReleaseRateLimitOccupancy", # 释放速率限制占用 diff --git a/templates/pipeline.json b/templates/pipeline.json index aefc195c..ef2227ed 100644 --- a/templates/pipeline.json +++ b/templates/pipeline.json @@ -34,5 +34,11 @@ "limit": 60 } } + }, + "msg-truncate": { + "method": "round", + "round": { + "max-round": 10 + } } } \ No newline at end of file