refactor: 重构插件系统

This commit is contained in:
RockChinQ
2024-01-29 21:22:27 +08:00
parent b730f17eb6
commit 6cc4688660
53 changed files with 1307 additions and 1993 deletions
-232
View File
@@ -1,232 +0,0 @@
import json
import logging
import openai
from openai.types.chat import chat_completion_message
from .model import RequestBase
from .. import funcmgr
from ...plugin import host
from ...utils import context
class ChatCompletionRequest(RequestBase):
"""调用ChatCompletion接口的请求类。
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
"""
model: str
messages: list[dict[str, str]]
kwargs: dict
stopped: bool = False
pending_func_call: chat_completion_message.FunctionCall = None
pending_msg: str
def flush_pending_msg(self):
self.append_message(
role="assistant",
content=self.pending_msg
)
self.pending_msg = ""
def append_message(self, role: str, content: str, name: str=None, function_call: dict=None):
msg = {
"role": role,
"content": content
}
if name is not None:
msg['name'] = name
if function_call is not None:
msg['function_call'] = function_call
self.messages.append(msg)
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.messages = messages.copy()
self.kwargs = kwargs
self.req_func = self.client.chat.completions.create
self.pending_func_call = None
self.stopped = False
self.pending_msg = ""
def __iter__(self):
return self
def __next__(self) -> dict:
if self.stopped:
raise StopIteration()
if self.pending_func_call is None: # 没有待处理的函数调用请求
args = {
"model": self.model,
"messages": self.messages,
}
funcs = funcmgr.get_func_schema_list()
if len(funcs) > 0:
args['functions'] = funcs
# 拼接kwargs
args = {**args, **self.kwargs}
from openai.types.chat import chat_completion
resp: chat_completion.ChatCompletion = self._req(**args)
choice0 = resp.choices[0]
# 如果不是函数调用,且finish_reason为stop,则停止迭代
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
self.stopped = True
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
self.pending_func_call = choice0.message.function_call
self.append_message(
role="assistant",
content=choice0.message.content,
function_call=choice0.message.function_call
)
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "function_call",
"content": choice0.message.content,
"function_call": {
"name": choice0.message.function_call.name,
"arguments": choice0.message.function_call.arguments
}
},
"finish_reason": "function_call"
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else:
# self.pending_msg += choice0['message']['content']
# 普通回复一定处于最后方,故不用再追加进内部messages
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.message.content
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else: # 处理函数调用请求
cp_pending_func_call = self.pending_func_call.copy()
self.pending_func_call = None
func_name = cp_pending_func_call.name
arguments = {}
try:
try:
arguments = json.loads(cp_pending_func_call.arguments)
# 若不是json格式的异常处理
except json.decoder.JSONDecodeError:
# 获取函数的参数列表
func_schema = funcmgr.get_func_schema(func_name)
arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
}
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
# 执行函数调用
ret = ""
try:
ret = funcmgr.execute_function(func_name, arguments)
logging.info("函数执行完成。")
except Exception as e:
ret = "error: execute function failed: {}".format(str(e))
logging.error("函数执行失败: {}".format(str(e)))
# 上报数据
plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0])
audit_func_name = func_name.split('-')[1]
audit_func_desc = funcmgr.get_func_schema(func_name)['description']
context.get_center_v2_api().usage.post_function_record(
plugin=plugin_info,
function_name=audit_func_name,
function_description=audit_func_desc,
)
self.append_message(
role="function",
content=json.dumps(ret, ensure_ascii=False),
name=func_name
)
return {
"id": -1,
"choices": [
{
"index": -1,
"message": {
"role": "function",
"type": "function_return",
"function_name": func_name,
"content": json.dumps(ret, ensure_ascii=False)
},
"finish_reason": "function_return"
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except funcmgr.ContentFunctionNotFoundError:
raise Exception("没有找到函数: {}".format(func_name))
-100
View File
@@ -1,100 +0,0 @@
import openai
from openai.types import completion, completion_choice
from . import model
class CompletionRequest(model.RequestBase):
"""调用Completion接口的请求类。
调用方可以一直next completion直到finish_reason为stop。
"""
model: str
prompt: str
kwargs: dict
stopped: bool = False
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.prompt = ""
for message in messages:
self.prompt += message["role"] + ": " + message["content"] + "\n"
self.prompt += "assistant: "
self.kwargs = kwargs
self.req_func = self.client.completions.create
def __iter__(self):
return self
def __next__(self) -> dict:
"""调用Completion接口,返回生成的文本
{
"id": "id",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"type": "text",
"content": "message"
},
"finish_reason": "reason"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
"""
if self.stopped:
raise StopIteration()
resp: completion.Completion = self._req(
model=self.model,
prompt=self.prompt,
**self.kwargs
)
if resp.choices[0].finish_reason == "stop":
self.stopped = True
choice0: completion_choice.CompletionChoice = resp.choices[0]
self.prompt += choice0.text
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.text
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
-40
View File
@@ -1,40 +0,0 @@
# 定义不同接口请求的模型
import logging
import openai
from ...utils import context
class RequestBase:
client: openai.Client
req_func: callable
def __init__(self, *args, **kwargs):
raise NotImplementedError
def _next_key(self):
switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
def _req(self, **kwargs):
"""处理代理问题"""
logging.debug("请求接口参数: %s", str(kwargs))
config = context.get_config_manager().data
ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))
if config['switch_strategy'] == 'active':
self._next_key()
return ret
def __iter__(self):
raise self
def __next__(self):
raise NotImplementedError
+1 -2
View File
@@ -6,7 +6,6 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..session import entities as session_entities
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
@@ -24,7 +23,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def request(
self,
query: core_entities.Query,
conversation: session_entities.Conversation,
conversation: core_entities.Conversation,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""
+7 -31
View File
@@ -10,7 +10,6 @@ import openai.types.chat.chat_completion as chat_completion
from .. import api
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...session import entities as session_entities
class OpenAIChatCompletion(api.LLMAPIRequester):
@@ -43,41 +42,18 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
async def _closure(
self,
req_messages: list[dict],
conversation: session_entities.Conversation,
user_text: str = None,
function_ret: str = None,
conversation: core_entities.Conversation,
) -> llm_entities.Message:
self.client.api_key = conversation.use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args["model"] = conversation.use_model.name
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
# tools = [
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
if tools:
args["tools"] = tools
if conversation.use_model.tool_call_supported:
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
@@ -92,7 +68,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
return message
async def request(
self, query: core_entities.Query, conversation: session_entities.Conversation
self, query: core_entities.Query, conversation: core_entities.Conversation
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
+6 -2
View File
@@ -1,9 +1,11 @@
from __future__ import annotations
import typing
import pydantic
from . import api
from . import token
from . import token, tokenizer
class LLMModelInfo(pydantic.BaseModel):
@@ -17,7 +19,9 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester
function_call_supported: typing.Optional[bool] = False
tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False
class Config:
arbitrary_types_allowed = True
+13 -9
View File
@@ -5,6 +5,7 @@ from ...core import app
from .apis import chatcmpl
from . import token
from .tokenizers import tiktoken
class ModelManager:
@@ -17,25 +18,28 @@ class ModelManager:
self.ap = ap
self.model_list = []
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
self.model_list.append(
entities.LLMModelInfo(
name="gpt-3.5-turbo",
provider="openai",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
function_call_supported=True
tool_call_supported=True,
tokenizer=tiktoken_tokenizer
)
)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass
@@ -0,0 +1,28 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
-53
View File
@@ -1,53 +0,0 @@
from __future__ import annotations
import datetime
import asyncio
import typing
import pydantic
from ..sysprompt import entities as sysprompt_entities
from .. import entities as llm_entities
from ..requester import entities
from ...core import entities as core_entities
from ..tools import entities as tools_entities
class Conversation(pydantic.BaseModel):
"""对话"""
prompt: sysprompt_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
class Session(pydantic.BaseModel):
"""会话"""
launcher_type: core_entities.LauncherTypes
launcher_id: int
sender_id: typing.Optional[int] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = []
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
class Config:
arbitrary_types_allowed = True
+5 -6
View File
@@ -3,14 +3,13 @@ from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from . import entities
class SessionManager:
ap: app.Application
session_list: list[entities.Session]
session_list: list[core_entities.Session]
def __init__(self, ap: app.Application):
self.ap = ap
@@ -19,14 +18,14 @@ class SessionManager:
async def initialize(self):
pass
async def get_session(self, query: core_entities.Query) -> entities.Session:
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
"""获取会话
"""
for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
return session
session = entities.Session(
session = core_entities.Session(
launcher_type=query.launcher_type,
launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
@@ -34,12 +33,12 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
if not session.conversations:
session.conversations = []
if session.using_conversation is None:
conversation = entities.Conversation(
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
+2
View File
@@ -6,6 +6,8 @@ import asyncio
import pydantic
from ...core import entities as core_entities
class LLMFunction(pydantic.BaseModel):
"""函数"""
+8 -26
View File
@@ -4,7 +4,6 @@ import typing
from ...core import app, entities as core_entities
from . import entities
from ..session import entities as session_entities
class ToolManager:
@@ -12,8 +11,6 @@ class ToolManager:
"""
ap: app.Application
all_functions: list[entities.LLMFunction]
def __init__(self, ap: app.Application):
self.ap = ap
@@ -22,30 +19,10 @@ class ToolManager:
async def initialize(self):
pass
def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable):
"""注册函数
"""
async def wrapper(query, **kwargs):
return func(**kwargs)
function = entities.LLMFunction(
name=name,
description=description,
human_desc='',
enable=True,
parameters=parameters,
func=wrapper
)
self.all_functions.append(function)
async def register_function(self, function: entities.LLMFunction):
"""添加函数
"""
self.all_functions.append(function)
async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数
"""
for function in self.all_functions:
for function in await self.get_all_functions():
if function.name == name:
return function
return None
@@ -53,9 +30,14 @@ class ToolManager:
async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数
"""
return self.all_functions
all_functions: list[entities.LLMFunction] = []
async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str:
for plugin in self.ap.plugin_mgr.plugins:
all_functions.extend(plugin.content_functions)
return all_functions
async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str:
"""生成函数列表
"""
tools = []