From b9fa11c0c39236691d018d248d0391a8a1792a07 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Tue, 12 Mar 2024 16:22:07 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20prompt=20=E5=8A=A0=E8=BD=BD=E5=99=A8?= =?UTF-8?q?=E7=9A=84=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/sysprompt/loader.py | 14 ++++++++++++++ pkg/provider/sysprompt/loaders/scenario.py | 1 + pkg/provider/sysprompt/loaders/single.py | 1 + pkg/provider/sysprompt/sysprompt.py | 12 +++++++----- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py index ca9e8730..9e0a6144 100644 --- a/pkg/provider/sysprompt/loader.py +++ b/pkg/provider/sysprompt/loader.py @@ -1,13 +1,27 @@ from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_loaders: list[typing.Type[PromptLoader]] = [] + +def loader_class(name: str): + + def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: + cls.name = name + preregistered_loaders.append(cls) + return cls + + return decorator + + class PromptLoader(metaclass=abc.ABCMeta): """Prompt加载器抽象类 """ + name: str ap: app.Application diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py index a559ff73..9c19d963 100644 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ b/pkg/provider/sysprompt/loaders/scenario.py @@ -8,6 +8,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("full_scenario") class ScenarioPromptLoader(loader.PromptLoader): """加载scenario目录下的json""" diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py index 57e06ed2..3ac9c262 100644 --- a/pkg/provider/sysprompt/loaders/single.py +++ b/pkg/provider/sysprompt/loaders/single.py @@ -6,6 +6,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("normal") class SingleSystemPromptLoader(loader.PromptLoader): """配置文件中的单条system prompt的prompt加载器 """ diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index eb89e8ab..61c598ed 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -20,12 +20,14 @@ class PromptManager: async def initialize(self): - loader_map = { - "normal": single.SingleSystemPromptLoader, - "full_scenario": scenario.ScenarioPromptLoader - } + mode_name = self.ap.provider_cfg.data['prompt-mode'] - loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']] + for loader_cls in loader.preregistered_loaders: + if loader_cls.name == mode_name: + loader_cls = loader_cls + break + else: + raise ValueError(f'未知的 Prompt 加载器: {mode_name}') self.loader_inst: loader.PromptLoader = loader_cls(self.ap)