mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 00:06:04 +00:00
refactor: move prompt mgm to pipeline
This commit is contained in:
@@ -4,6 +4,8 @@ import typing
|
||||
import enum
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from pkg.provider import entities
|
||||
|
||||
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
return platform_message.MessageChain(mc)
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...plugin import context as plugin_context
|
||||
from ...provider import entities as provider_entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@@ -41,15 +42,26 @@ class SessionManager:
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
# set prompt
|
||||
prompt_messages = []
|
||||
|
||||
for prompt_message in pipeline_config['ai']['local-agent']['prompt']:
|
||||
prompt_messages.append(provider_entities.Message(**prompt_message))
|
||||
|
||||
prompt = provider_entities.Prompt(
|
||||
name="default",
|
||||
messages=prompt_messages,
|
||||
)
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = core_entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
prompt=prompt,
|
||||
messages=[],
|
||||
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
|
||||
query.pipeline_config['ai']['local-agent']['model']
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...provider import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||
|
||||
def loader_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt,存放到prompts列表中
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("full-scenario")
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("data/scenario"):
|
||||
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = 'system'
|
||||
if "role" in msg:
|
||||
role = msg['role']
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("normal")
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("data/prompts"):
|
||||
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
@@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Prompt管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||
|
||||
loader_class = None
|
||||
|
||||
for loader_cls in loader.preregistered_loaders:
|
||||
if loader_cls.name == mode_name:
|
||||
loader_class = loader_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
|
||||
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
|
||||
"""通过前缀获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name.startswith(prefix):
|
||||
return prompt
|
||||
Reference in New Issue
Block a user