refactor: 重构插件系统

This commit is contained in:
RockChinQ
2024-01-29 21:22:27 +08:00
parent b730f17eb6
commit 6cc4688660
53 changed files with 1307 additions and 1993 deletions
+1 -2
View File
@@ -6,7 +6,6 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..session import entities as session_entities
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
@@ -24,7 +23,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def request(
self,
query: core_entities.Query,
conversation: session_entities.Conversation,
conversation: core_entities.Conversation,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""
+7 -31
View File
@@ -10,7 +10,6 @@ import openai.types.chat.chat_completion as chat_completion
from .. import api
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...session import entities as session_entities
class OpenAIChatCompletion(api.LLMAPIRequester):
@@ -43,41 +42,18 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
async def _closure(
self,
req_messages: list[dict],
conversation: session_entities.Conversation,
user_text: str = None,
function_ret: str = None,
conversation: core_entities.Conversation,
) -> llm_entities.Message:
self.client.api_key = conversation.use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args["model"] = conversation.use_model.name
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
# tools = [
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
if tools:
args["tools"] = tools
if conversation.use_model.tool_call_supported:
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
@@ -92,7 +68,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
return message
async def request(
self, query: core_entities.Query, conversation: session_entities.Conversation
self, query: core_entities.Query, conversation: core_entities.Conversation
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
+6 -2
View File
@@ -1,9 +1,11 @@
from __future__ import annotations
import typing
import pydantic
from . import api
from . import token
from . import token, tokenizer
class LLMModelInfo(pydantic.BaseModel):
@@ -17,7 +19,9 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester
function_call_supported: typing.Optional[bool] = False
tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False
class Config:
arbitrary_types_allowed = True
+13 -9
View File
@@ -5,6 +5,7 @@ from ...core import app
from .apis import chatcmpl
from . import token
from .tokenizers import tiktoken
class ModelManager:
@@ -17,25 +18,28 @@ class ModelManager:
self.ap = ap
self.model_list = []
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
self.model_list.append(
entities.LLMModelInfo(
name="gpt-3.5-turbo",
provider="openai",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
function_call_supported=True
tool_call_supported=True,
tokenizer=tiktoken_tokenizer
)
)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass
@@ -0,0 +1,28 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens