refactor: move prompt mgm to pipeline

This commit is contained in:
Junyan Qin
2025-04-03 17:06:01 +08:00
parent 913e43d84c
commit fb18278bdc
16 changed files with 36 additions and 285 deletions

View File

@@ -79,9 +79,12 @@ class PipelineService:
return pipeline_data['uuid']
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
del pipeline_data['uuid']
del pipeline_data['for_version']
del pipeline_data['stages']
if 'uuid' in pipeline_data:
del pipeline_data['uuid']
if 'for_version' in pipeline_data:
del pipeline_data['for_version']
if 'stages' in pipeline_data:
del pipeline_data['stages']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
)

View File

@@ -8,7 +8,7 @@ from . import entities, operator, errors
from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
class CommandManager:

View File

@@ -1,62 +0,0 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="default",
help="操作情景预设",
usage='!default\n!default set <指定情景预设为默认>'
)
class DefaultOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前所有情景预设: \n\n"
for prompt in self.ap.prompt_mgr.get_all_prompts():
content = ""
for msg in prompt.messages:
content += f" {msg.readable_str()}\n"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
yield entities.CommandReturn(text=reply_str.strip())
@operator.operator_class(
name="set",
help="设置当前会话默认情景预设",
parent_class=DefaultOperator
)
class DefaultSetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -11,7 +11,6 @@ import os
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..config import settings as settings_mgr
@@ -52,9 +51,6 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
# TODO 移动到 pipeline 里
prompt_mgr: llm_prompt_mgr.PromptManager = None
# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
@@ -229,10 +225,6 @@ class Application:
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self)
await llm_prompt_mgr_inst.initialize()
self.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst

View File

@@ -9,7 +9,6 @@ import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities, modelmgr, requester
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
@@ -75,7 +74,7 @@ class Query(pydantic.BaseModel):
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
prompt: typing.Optional[llm_entities.Prompt] = None
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
@@ -127,7 +126,7 @@ class Query(pydantic.BaseModel):
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: sysprompt_entities.Prompt
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]

View File

@@ -11,7 +11,6 @@ from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
@@ -95,10 +94,6 @@ class BuildAppStage(stage.BootingStage):
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst

View File

@@ -33,7 +33,7 @@ class PreProcessor(stage.PipelineStage):
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(query, session)
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config)
# 设置query
query.session = session

View File

@@ -4,6 +4,8 @@ import typing
import enum
import pydantic.v1 as pydantic
from pkg.provider import entities
from ..platform.types import message as platform_message
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
mc.insert(0, platform_message.Plain(prefix_text))
return platform_message.MessageChain(mc)
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -4,6 +4,7 @@ import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
from ...provider import entities as provider_entities
class SessionManager:
@@ -41,15 +42,26 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
session.conversations = []
# set prompt
prompt_messages = []
for prompt_message in pipeline_config['ai']['local-agent']['prompt']:
prompt_messages.append(provider_entities.Message(**prompt_message))
prompt = provider_entities.Prompt(
name="default",
messages=prompt_messages,
)
if session.using_conversation is None:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
prompt=prompt,
messages=[],
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
query.pipeline_config['ai']['local-agent']['model']

View File

@@ -1,16 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ...provider import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -1,46 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt存放到prompts列表中
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full-scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("data/scenario"):
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = 'system'
if "role" in msg:
role = msg['role']
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role='system',
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("data/prompts"):
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role='system',
content=file_str
)
]
)
self.prompts.append(prompt)

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
"""Prompt管理器
"""
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_class = None
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt