chore: 修改包名

This commit is contained in:
RockChinQ
2024-01-28 19:20:10 +08:00
parent 698782c537
commit b730f17eb6
45 changed files with 27 additions and 27 deletions
View File
+14
View File
@@ -0,0 +1,14 @@
from __future__ import annotations
import typing
import pydantic
from ...provider import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
messages: list[entities.Message]
+32
View File
@@ -0,0 +1,32 @@
from __future__ import annotations
import abc
from ...core import app
from . import entities
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts
@@ -0,0 +1,38 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("scenarios"):
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = 'system'
if "role" in msg:
role = msg['role']
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)
+42
View File
@@ -0,0 +1,42 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role='system',
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("prompts"):
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role='system',
content=file_str
)
]
)
self.prompts.append(prompt)
+50
View File
@@ -0,0 +1,50 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt