refactor: move prompt mgm to pipeline

This commit is contained in:
Junyan Qin
2025-04-03 17:06:01 +08:00
parent 913e43d84c
commit fb18278bdc
16 changed files with 36 additions and 285 deletions

View File

@@ -4,6 +4,8 @@ import typing
import enum
import pydantic.v1 as pydantic
from pkg.provider import entities
from ..platform.types import message as platform_message
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
mc.insert(0, platform_message.Plain(prefix_text))
return platform_message.MessageChain(mc)
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -4,6 +4,7 @@ import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
from ...provider import entities as provider_entities
class SessionManager:
@@ -41,15 +42,26 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
session.conversations = []
# set prompt
prompt_messages = []
for prompt_message in pipeline_config['ai']['local-agent']['prompt']:
prompt_messages.append(provider_entities.Message(**prompt_message))
prompt = provider_entities.Prompt(
name="default",
messages=prompt_messages,
)
if session.using_conversation is None:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
prompt=prompt,
messages=[],
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
query.pipeline_config['ai']['local-agent']['model']

View File

@@ -1,16 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ...provider import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -1,46 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt存放到prompts列表中
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full-scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("data/scenario"):
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = 'system'
if "role" in msg:
role = msg['role']
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role='system',
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("data/prompts"):
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role='system',
content=file_str
)
]
)
self.prompts.append(prompt)

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
"""Prompt管理器
"""
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_class = None
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt