mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 00:06:04 +00:00
refactor: 重构openai包基础组件架构
This commit is contained in:
31
pkg/openai/entities.py
Normal file
31
pkg/openai/entities.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import enum
|
||||
import pydantic
|
||||
|
||||
|
||||
class MessageRole(enum.Enum):
|
||||
|
||||
SYSTEM = 'system'
|
||||
|
||||
USER = 'user'
|
||||
|
||||
ASSISTANT = 'assistant'
|
||||
|
||||
FUNCTION = 'function'
|
||||
|
||||
|
||||
class FunctionCall(pydantic.BaseModel):
|
||||
name: str
|
||||
|
||||
args: dict[str, typing.Any]
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
|
||||
role: MessageRole
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
||||
0
pkg/openai/requester/__init__.py
Normal file
0
pkg/openai/requester/__init__.py
Normal file
31
pkg/openai/requester/api.py
Normal file
31
pkg/openai/requester/api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
conversation: session_entities.Conversation,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
0
pkg/openai/requester/apis/__init__.py
Normal file
0
pkg/openai/requester/apis/__init__.py
Normal file
32
pkg/openai/requester/apis/chatcmpl.py
Normal file
32
pkg/openai/requester/apis/chatcmpl.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import openai
|
||||
|
||||
from .. import api
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...session import entities as session_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
|
||||
client: openai.Client
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.Client(
|
||||
base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'],
|
||||
timeout=self.ap.cfg_mgr.data['process_message_timeout']
|
||||
)
|
||||
|
||||
async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
await asyncio.sleep(10)
|
||||
|
||||
yield llm_entities.Message(
|
||||
role=llm_entities.MessageRole.ASSISTANT,
|
||||
content="hello"
|
||||
)
|
||||
23
pkg/openai/requester/entities.py
Normal file
23
pkg/openai/requester/entities.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import api
|
||||
from . import token
|
||||
|
||||
|
||||
class LLMModelInfo(pydantic.BaseModel):
|
||||
"""模型"""
|
||||
|
||||
name: str
|
||||
|
||||
provider: str
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
|
||||
requester: api.LLMAPIRequester
|
||||
|
||||
function_call_supported: typing.Optional[bool] = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
40
pkg/openai/requester/modelmgr.py
Normal file
40
pkg/openai/requester/modelmgr.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from .apis import chatcmpl
|
||||
from . import token
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
|
||||
async def initialize(self):
|
||||
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||
openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values())
|
||||
|
||||
self.model_list.append(
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo",
|
||||
provider="openai",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
function_call_supported=True
|
||||
)
|
||||
)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"Model {name} not found")
|
||||
25
pkg/openai/requester/token.py
Normal file
25
pkg/openai/requester/token.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class TokenManager():
|
||||
|
||||
provider: str
|
||||
|
||||
tokens: list[str]
|
||||
|
||||
using_token_index: typing.Optional[int] = 0
|
||||
|
||||
def __init__(self, provider: str, tokens: list[str]):
|
||||
self.provider = provider
|
||||
self.tokens = tokens
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||
0
pkg/openai/session/__init__.py
Normal file
0
pkg/openai/session/__init__.py
Normal file
50
pkg/openai/session/entities.py
Normal file
50
pkg/openai/session/entities.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..sysprompt import entities as sysprompt_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..requester import entities
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_model: entities.LLMModelInfo
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话"""
|
||||
launcher_type: core_entities.LauncherTypes
|
||||
|
||||
launcher_id: int
|
||||
|
||||
sender_id: typing.Optional[int] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = []
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
50
pkg/openai/session/sessionmgr.py
Normal file
50
pkg/openai/session/sessionmgr.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
session_list: list[entities.Session]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.session_list = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def get_session(self, query: core_entities.Query) -> entities.Session:
|
||||
"""获取会话
|
||||
"""
|
||||
for session in self.session_list:
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
return session
|
||||
|
||||
session = entities.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000)
|
||||
)
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
session.using_conversation = conversation
|
||||
|
||||
return session.using_conversation
|
||||
0
pkg/openai/sysprompt/__init__.py
Normal file
0
pkg/openai/sysprompt/__init__.py
Normal file
14
pkg/openai/sysprompt/entities.py
Normal file
14
pkg/openai/sysprompt/entities.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic
|
||||
|
||||
from ...openai import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
|
||||
messages: list[entities.Message]
|
||||
32
pkg/openai/sysprompt/loader.py
Normal file
32
pkg/openai/sysprompt/loader.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
0
pkg/openai/sysprompt/loaders/__init__.py
Normal file
0
pkg/openai/sysprompt/loaders/__init__.py
Normal file
43
pkg/openai/sysprompt/loaders/scenario.py
Normal file
43
pkg/openai/sysprompt/loaders/scenario.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("scenarios"):
|
||||
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = llm_entities.MessageRole.SYSTEM
|
||||
if "role" in msg:
|
||||
if msg["role"] == "user":
|
||||
role = llm_entities.MessageRole.USER
|
||||
elif msg["role"] == "system":
|
||||
role = llm_entities.MessageRole.SYSTEM
|
||||
elif msg["role"] == "function":
|
||||
role = llm_entities.MessageRole.FUNCTION
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
42
pkg/openai/sysprompt/loaders/single.py
Normal file
42
pkg/openai/sysprompt/loaders/single.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role=llm_entities.MessageRole.SYSTEM,
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("prompts"):
|
||||
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role=llm_entities.MessageRole.SYSTEM,
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
43
pkg/openai/sysprompt/sysprompt.py
Normal file
43
pkg/openai/sysprompt/sysprompt.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
|
||||
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
Reference in New Issue
Block a user