mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-03 04:24:36 +00:00
refactor: move prompt mgm to pipeline
This commit is contained in:
@@ -79,9 +79,12 @@ class PipelineService:
|
||||
return pipeline_data['uuid']
|
||||
|
||||
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
|
||||
del pipeline_data['uuid']
|
||||
del pipeline_data['for_version']
|
||||
del pipeline_data['stages']
|
||||
if 'uuid' in pipeline_data:
|
||||
del pipeline_data['uuid']
|
||||
if 'for_version' in pipeline_data:
|
||||
del pipeline_data['for_version']
|
||||
if 'stages' in pipeline_data:
|
||||
del pipeline_data['stages']
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from . import entities, operator, errors
|
||||
from ..config import manager as cfg_mgr
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
||||
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
||||
|
||||
|
||||
class CommandManager:
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="default",
|
||||
help="操作情景预设",
|
||||
usage='!default\n!default set <指定情景预设为默认>'
|
||||
)
|
||||
class DefaultOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
reply_str = "当前所有情景预设: \n\n"
|
||||
|
||||
for prompt in self.ap.prompt_mgr.get_all_prompts():
|
||||
|
||||
content = ""
|
||||
for msg in prompt.messages:
|
||||
content += f" {msg.readable_str()}\n"
|
||||
|
||||
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
|
||||
|
||||
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="set",
|
||||
help="设置当前会话默认情景预设",
|
||||
parent_class=DefaultOperator
|
||||
)
|
||||
class DefaultSetOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
|
||||
else:
|
||||
prompt_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
|
||||
if prompt is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
|
||||
else:
|
||||
context.session.use_prompt_name = prompt.name
|
||||
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
|
||||
@@ -11,7 +11,6 @@ import os
|
||||
from ..platform import manager as im_mgr
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..config import settings as settings_mgr
|
||||
@@ -52,9 +51,6 @@ class Application:
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
# TODO 移动到 pipeline 里
|
||||
prompt_mgr: llm_prompt_mgr.PromptManager = None
|
||||
|
||||
# TODO 移动到 pipeline 里
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
@@ -229,10 +225,6 @@ class Application:
|
||||
await llm_session_mgr_inst.initialize()
|
||||
self.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
self.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
self.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
@@ -9,7 +9,6 @@ import pydantic.v1 as pydantic
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.modelmgr import entities, modelmgr, requester
|
||||
from ..provider.sysprompt import entities as sysprompt_entities
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
@@ -75,7 +74,7 @@ class Query(pydantic.BaseModel):
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[sysprompt_entities.Prompt] = None
|
||||
prompt: typing.Optional[llm_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
@@ -127,7 +126,7 @@ class Query(pydantic.BaseModel):
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
prompt: llm_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from ...plugin import manager as plugin_mgr
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...platform import manager as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
@@ -95,10 +94,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(query, session)
|
||||
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config)
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
|
||||
@@ -4,6 +4,8 @@ import typing
|
||||
import enum
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from pkg.provider import entities
|
||||
|
||||
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
return platform_message.MessageChain(mc)
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...plugin import context as plugin_context
|
||||
from ...provider import entities as provider_entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@@ -41,15 +42,26 @@ class SessionManager:
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
# set prompt
|
||||
prompt_messages = []
|
||||
|
||||
for prompt_message in pipeline_config['ai']['local-agent']['prompt']:
|
||||
prompt_messages.append(provider_entities.Message(**prompt_message))
|
||||
|
||||
prompt = provider_entities.Prompt(
|
||||
name="default",
|
||||
messages=prompt_messages,
|
||||
)
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = core_entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
prompt=prompt,
|
||||
messages=[],
|
||||
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
|
||||
query.pipeline_config['ai']['local-agent']['model']
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...provider import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||
|
||||
def loader_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt,存放到prompts列表中
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("full-scenario")
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("data/scenario"):
|
||||
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = 'system'
|
||||
if "role" in msg:
|
||||
role = msg['role']
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("normal")
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("data/prompts"):
|
||||
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
@@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Prompt管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||
|
||||
loader_class = None
|
||||
|
||||
for loader_cls in loader.preregistered_loaders:
|
||||
if loader_cls.name == mode_name:
|
||||
loader_class = loader_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
|
||||
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
|
||||
"""通过前缀获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name.startswith(prefix):
|
||||
return prompt
|
||||
Reference in New Issue
Block a user