From fb18278bdcf0c43aedaa519c93c8a8dbad6e5c16 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 3 Apr 2025 17:06:01 +0800 Subject: [PATCH] refactor: move prompt mgm to pipeline --- pkg/api/http/service/pipeline.py | 9 ++-- pkg/command/cmdmgr.py | 2 +- pkg/command/operators/default.py | 62 ---------------------- pkg/core/app.py | 8 --- pkg/core/entities.py | 5 +- pkg/core/stages/build_app.py | 5 -- pkg/pipeline/preproc/preproc.py | 2 +- pkg/provider/entities.py | 12 +++++ pkg/provider/session/sessionmgr.py | 16 +++++- pkg/provider/sysprompt/__init__.py | 0 pkg/provider/sysprompt/entities.py | 16 ------ pkg/provider/sysprompt/loader.py | 46 ---------------- pkg/provider/sysprompt/loaders/__init__.py | 0 pkg/provider/sysprompt/loaders/scenario.py | 39 -------------- pkg/provider/sysprompt/loaders/single.py | 43 --------------- pkg/provider/sysprompt/sysprompt.py | 56 ------------------- 16 files changed, 36 insertions(+), 285 deletions(-) delete mode 100644 pkg/command/operators/default.py delete mode 100644 pkg/provider/sysprompt/__init__.py delete mode 100644 pkg/provider/sysprompt/entities.py delete mode 100644 pkg/provider/sysprompt/loader.py delete mode 100644 pkg/provider/sysprompt/loaders/__init__.py delete mode 100644 pkg/provider/sysprompt/loaders/scenario.py delete mode 100644 pkg/provider/sysprompt/loaders/single.py delete mode 100644 pkg/provider/sysprompt/sysprompt.py diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index f1bcaa75..72b7daf7 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -79,9 +79,12 @@ class PipelineService: return pipeline_data['uuid'] async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None: - del pipeline_data['uuid'] - del pipeline_data['for_version'] - del pipeline_data['stages'] + if 'uuid' in pipeline_data: + del pipeline_data['uuid'] + if 'for_version' in pipeline_data: + del pipeline_data['for_version'] + if 'stages' in pipeline_data: + del pipeline_data['stages'] await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data) ) diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 8d442fdb..ea4e1a9b 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model +from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model class CommandManager: diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py deleted file mode 100644 index ee46c7d0..00000000 --- a/pkg/command/operators/default.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import typing -import traceback - -from .. import operator, entities, cmdmgr, errors - - -@operator.operator_class( - name="default", - help="操作情景预设", - usage='!default\n!default set <指定情景预设为默认>' -) -class DefaultOperator(operator.CommandOperator): - - async def execute( - self, - context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - - reply_str = "当前所有情景预设: \n\n" - - for prompt in self.ap.prompt_mgr.get_all_prompts(): - - content = "" - for msg in prompt.messages: - content += f" {msg.readable_str()}\n" - - reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" - - reply_str += f"当前会话使用的是: {context.session.use_prompt_name}" - - yield entities.CommandReturn(text=reply_str.strip()) - - -@operator.operator_class( - name="set", - help="设置当前会话默认情景预设", - parent_class=DefaultOperator -) -class DefaultSetOperator(operator.CommandOperator): - - async def execute( - self, - context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) - else: - prompt_name = context.crt_params[0] - - try: - prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) - if prompt is None: - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) - else: - context.session.use_prompt_name = prompt.name - yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") - except Exception as e: - traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/core/app.py b/pkg/core/app.py index 126b165a..bfea617b 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -11,7 +11,6 @@ import os from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr -from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..config import settings as settings_mgr @@ -52,9 +51,6 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None - # TODO 移动到 pipeline 里 - prompt_mgr: llm_prompt_mgr.PromptManager = None - # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -229,10 +225,6 @@ class Application: await llm_session_mgr_inst.initialize() self.sess_mgr = llm_session_mgr_inst - llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self) - await llm_prompt_mgr_inst.initialize() - self.prompt_mgr = llm_prompt_mgr_inst - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self) await llm_tool_mgr_inst.initialize() self.tool_mgr = llm_tool_mgr_inst diff --git a/pkg/core/entities.py b/pkg/core/entities.py index d8768494..1753495b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -9,7 +9,6 @@ import pydantic.v1 as pydantic from ..provider import entities as llm_entities from ..provider.modelmgr import entities, modelmgr, requester -from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter from ..platform.types import message as platform_message @@ -75,7 +74,7 @@ class Query(pydantic.BaseModel): messages: typing.Optional[list[llm_entities.Message]] = [] """历史消息列表,由前置处理器阶段设置""" - prompt: typing.Optional[sysprompt_entities.Prompt] = None + prompt: typing.Optional[llm_entities.Prompt] = None """情景预设内容,由前置处理器阶段设置""" user_message: typing.Optional[llm_entities.Message] = None @@ -127,7 +126,7 @@ class Query(pydantic.BaseModel): class Conversation(pydantic.BaseModel): """对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation""" - prompt: sysprompt_entities.Prompt + prompt: llm_entities.Prompt messages: list[llm_entities.Message] diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index fcd930a3..fc049d9c 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -11,7 +11,6 @@ from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr -from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr from ...persistence import mgr as persistencemgr @@ -95,10 +94,6 @@ class BuildAppStage(stage.BootingStage): await llm_session_mgr_inst.initialize() ap.sess_mgr = llm_session_mgr_inst - llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap) - await llm_prompt_mgr_inst.initialize() - ap.prompt_mgr = llm_prompt_mgr_inst - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 9958466a..8cf51463 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -33,7 +33,7 @@ class PreProcessor(stage.PipelineStage): """ session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(query, session) + conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config) # 设置query query.session = session diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index dce55fd5..0fb75f80 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -4,6 +4,8 @@ import typing import enum import pydantic.v1 as pydantic +from pkg.provider import entities + from ..platform.types import message as platform_message @@ -124,3 +126,13 @@ class Message(pydantic.BaseModel): mc.insert(0, platform_message.Plain(prefix_text)) return platform_message.MessageChain(mc) + + +class Prompt(pydantic.BaseModel): + """供AI使用的Prompt""" + + name: str + """名称""" + + messages: list[entities.Message] + """消息列表""" diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 83691e4c..93b1146e 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -4,6 +4,7 @@ import asyncio from ...core import app, entities as core_entities from ...plugin import context as plugin_context +from ...provider import entities as provider_entities class SessionManager: @@ -41,15 +42,26 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation: + async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation: """获取对话或创建对话""" if not session.conversations: session.conversations = [] + # set prompt + prompt_messages = [] + + for prompt_message in pipeline_config['ai']['local-agent']['prompt']: + prompt_messages.append(provider_entities.Message(**prompt_message)) + + prompt = provider_entities.Prompt( + name="default", + messages=prompt_messages, + ) + if session.using_conversation is None: conversation = core_entities.Conversation( - prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), + prompt=prompt, messages=[], use_llm_model=await self.ap.model_mgr.get_model_by_uuid( query.pipeline_config['ai']['local-agent']['model'] diff --git a/pkg/provider/sysprompt/__init__.py b/pkg/provider/sysprompt/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/sysprompt/entities.py b/pkg/provider/sysprompt/entities.py deleted file mode 100644 index 5442e809..00000000 --- a/pkg/provider/sysprompt/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import typing -import pydantic.v1 as pydantic - -from ...provider import entities - - -class Prompt(pydantic.BaseModel): - """供AI使用的Prompt""" - - name: str - """名称""" - - messages: list[entities.Message] - """消息列表""" diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py deleted file mode 100644 index 855728e2..00000000 --- a/pkg/provider/sysprompt/loader.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations -import abc -import typing - -from ...core import app -from . import entities - - -preregistered_loaders: list[typing.Type[PromptLoader]] = [] - -def loader_class(name: str): - - def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: - cls.name = name - preregistered_loaders.append(cls) - return cls - - return decorator - - -class PromptLoader(metaclass=abc.ABCMeta): - """Prompt加载器抽象类 - """ - name: str - - ap: app.Application - - prompts: list[entities.Prompt] - - def __init__(self, ap: app.Application): - self.ap = ap - self.prompts = [] - - async def initialize(self): - pass - - @abc.abstractmethod - async def load(self): - """加载Prompt,存放到prompts列表中 - """ - raise NotImplementedError - - def get_prompts(self) -> list[entities.Prompt]: - """获取Prompt列表 - """ - return self.prompts diff --git a/pkg/provider/sysprompt/loaders/__init__.py b/pkg/provider/sysprompt/loaders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py deleted file mode 100644 index f907a51c..00000000 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import json -import os - -from .. import loader -from .. import entities -from ....provider import entities as llm_entities - - -@loader.loader_class("full-scenario") -class ScenarioPromptLoader(loader.PromptLoader): - """加载scenario目录下的json""" - - async def load(self): - """加载Prompt - """ - for file in os.listdir("data/scenario"): - with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f: - file_str = f.read() - file_name = file.split(".")[0] - file_json = json.loads(file_str) - messages = [] - for msg in file_json["prompt"]: - role = 'system' - if "role" in msg: - role = msg['role'] - messages.append( - llm_entities.Message( - role=role, - content=msg['content'], - ) - ) - prompt = entities.Prompt( - name=file_name, - messages=messages - ) - self.prompts.append(prompt) - \ No newline at end of file diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py deleted file mode 100644 index 3ac9c262..00000000 --- a/pkg/provider/sysprompt/loaders/single.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations -import os - -from .. import loader -from .. import entities -from ....provider import entities as llm_entities - - -@loader.loader_class("normal") -class SingleSystemPromptLoader(loader.PromptLoader): - """配置文件中的单条system prompt的prompt加载器 - """ - - async def load(self): - """加载Prompt - """ - - for name, cnt in self.ap.provider_cfg.data['prompt'].items(): - prompt = entities.Prompt( - name=name, - messages=[ - llm_entities.Message( - role='system', - content=cnt - ) - ] - ) - self.prompts.append(prompt) - - for file in os.listdir("data/prompts"): - with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f: - file_str = f.read() - file_name = file.split(".")[0] - prompt = entities.Prompt( - name=file_name, - messages=[ - llm_entities.Message( - role='system', - content=file_str - ) - ] - ) - self.prompts.append(prompt) diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py deleted file mode 100644 index c7695f5a..00000000 --- a/pkg/provider/sysprompt/sysprompt.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from ...core import app -from . import loader -from .loaders import single, scenario - - -class PromptManager: - """Prompt管理器 - """ - - ap: app.Application - - loader_inst: loader.PromptLoader - - default_prompt: str = 'default' - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - - mode_name = self.ap.provider_cfg.data['prompt-mode'] - - loader_class = None - - for loader_cls in loader.preregistered_loaders: - if loader_cls.name == mode_name: - loader_class = loader_cls - break - else: - raise ValueError(f'未知的 Prompt 加载器: {mode_name}') - - self.loader_inst: loader.PromptLoader = loader_class(self.ap) - - await self.loader_inst.initialize() - await self.loader_inst.load() - - def get_all_prompts(self) -> list[loader.entities.Prompt]: - """获取所有Prompt - """ - return self.loader_inst.get_prompts() - - async def get_prompt(self, name: str) -> loader.entities.Prompt: - """获取Prompt - """ - for prompt in self.get_all_prompts(): - if prompt.name == name: - return prompt - - async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt: - """通过前缀获取Prompt - """ - for prompt in self.get_all_prompts(): - if prompt.name.startswith(prefix): - return prompt