refactor: 重构openai包基础组件架构

This commit is contained in:
RockChinQ
2024-01-27 00:06:38 +08:00
parent 411034902a
commit 850a4eeb7c
35 changed files with 779 additions and 59 deletions
View File
+14
View File
@@ -0,0 +1,14 @@
from __future__ import annotations
import typing
import pydantic
from ...openai import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
messages: list[entities.Message]
+32
View File
@@ -0,0 +1,32 @@
from __future__ import annotations
import abc
from ...core import app
from . import entities
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("scenarios"):
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = llm_entities.MessageRole.SYSTEM
if "role" in msg:
if msg["role"] == "user":
role = llm_entities.MessageRole.USER
elif msg["role"] == "system":
role = llm_entities.MessageRole.SYSTEM
elif msg["role"] == "function":
role = llm_entities.MessageRole.FUNCTION
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)
+42
View File
@@ -0,0 +1,42 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role=llm_entities.MessageRole.SYSTEM,
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("prompts"):
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role=llm_entities.MessageRole.SYSTEM,
content=file_str
)
]
)
self.prompts.append(prompt)
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt