feat: 内容函数全局开关支持

This commit is contained in:
RockChinQ
2023-07-29 16:28:18 +08:00
parent ae6994e241
commit 8c69b8a1d9
6 changed files with 48 additions and 163 deletions

View File

@@ -1,5 +1,6 @@
import openai
import json
import logging
from .model import RequestBase
@@ -86,7 +87,7 @@ class ChatCompletionRequest(RequestBase):
self.append_message(
role="assistant",
content="function call: "+json.dumps(self.pending_func_call)
content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
)
return {
@@ -147,12 +148,16 @@ class ChatCompletionRequest(RequestBase):
func_schema['parameters']['required'][0]: cp_pending_func_call['arguments']
}
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
# 执行函数调用
ret = execute_function(func_name, arguments)
logging.info("函数执行完成。")
self.append_message(
role="function",
content=json.dumps(ret),
content=json.dumps(ret, ensure_ascii=False),
name=func_name
)
@@ -165,7 +170,7 @@ class ChatCompletionRequest(RequestBase):
"role": "function",
"type": "function_return",
"function_name": func_name,
"content": json.dumps(ret)
"content": json.dumps(ret, ensure_ascii=False)
},
"finish_reason": "function_return"
}

View File

@@ -16,12 +16,16 @@ class RequestBase:
"""处理代理问题"""
ret: dict = {}
exception: Exception = None
async def awrapper(**kwargs):
nonlocal ret
nonlocal ret, exception
ret = await self.req_func(**kwargs)
return ret
try:
ret = await self.req_func(**kwargs)
return ret
except Exception as e:
exception = e
loop = asyncio.new_event_loop()
@@ -33,6 +37,9 @@ class RequestBase:
thr.start()
thr.join()
if exception is not None:
raise exception
return ret
def __iter__(self):

View File

@@ -2,7 +2,7 @@
import logging
from pkg.plugin.host import __callable_functions__, __function_inst_map__
from pkg.plugin import host
class ContentFunctionNotFoundError(Exception):
@@ -11,19 +11,21 @@ class ContentFunctionNotFoundError(Exception):
def get_func_schema_list() -> list:
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
if not host.__enable_content_functions__:
return []
schemas = __callable_functions__
schemas = host.__callable_functions__
return schemas
def get_func(name: str) -> callable:
if name not in __function_inst_map__:
if name not in host.__function_inst_map__:
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
return __function_inst_map__[name]
return host.__function_inst_map__[name]
def get_func_schema(name: str) -> dict:
for func in __callable_functions__:
for func in host.__callable_functions__:
if func['name'] == name:
return func
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))

View File

@@ -43,157 +43,6 @@ IMAGE_MODELS = {
}
# class ModelRequest:
# """模型接口请求父类"""
# can_chat = False
# runtime: threading.Thread = None
# ret = {}
# proxy: str = None
# request_ready = True
# error_info: str = "若在没有任何错误的情况下看到这句话请带着配置文件上报Issues"
# def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
# self.model_name = model_name
# self.user_name = user_name
# self.request_fun = request_fun
# self.time_out = time_out
# if http_proxy != None:
# self.proxy = http_proxy
# openai.proxy = self.proxy
# self.request_ready = False
# async def __a_request__(self, **kwargs):
# """异步请求"""
# try:
# self.ret: dict = await self.request_fun(**kwargs)
# self.request_ready = True
# except aiE.APIConnectionError as e:
# self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
# raise ConnectionError(self.error_info)
# except ValueError as e:
# self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
# except Exception as e:
# self.error_info = "{}\n由于请求异常产生的未知错误请查看日志".format(e)
# raise type(e)(self.error_info)
# def request(self, **kwargs):
# """向接口发起请求"""
# if self.proxy != None: #异步请求
# self.request_ready = False
# loop = asyncio.new_event_loop()
# self.runtime = threading.Thread(
# target=loop.run_until_complete,
# args=(self.__a_request__(**kwargs),)
# )
# self.runtime.start()
# else: #同步请求
# self.ret = self.request_fun(**kwargs)
# def __msg_handle__(self, msg):
# """将prompt dict转换成接口需要的格式"""
# return msg
# def ret_handle(self):
# '''
# API消息返回处理函数
# 若重写该方法应检查异步线程状态或在需要检查处super该方法
# '''
# if self.runtime != None and isinstance(self.runtime, threading.Thread):
# self.runtime.join(self.time_out)
# if self.request_ready:
# return
# raise Exception(self.error_info)
# def get_total_tokens(self):
# try:
# return self.ret['usage']['total_tokens']
# except:
# return 0
# def get_message(self):
# return self.message
# def get_response(self):
# return self.ret
# class ChatCompletionModel(ModelRequest):
# """ChatCompletion接口的请求实现"""
# Chat_role = ['system', 'user', 'assistant']
# def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
# if http_proxy == None:
# request_fun = openai.ChatCompletion.create
# else:
# request_fun = openai.ChatCompletion.acreate
# self.can_chat = True
# super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
# def request(self, prompts, **kwargs):
# prompts = self.__msg_handle__(prompts)
# kwargs['messages'] = prompts
# super().request(**kwargs)
# self.ret_handle()
# def __msg_handle__(self, msgs):
# temp_msgs = []
# # 把msgs拷贝进temp_msgs
# for msg in msgs:
# temp_msgs.append(msg.copy())
# return temp_msgs
# def get_message(self):
# return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗
# class CompletionModel(ModelRequest):
# """Completion接口的请求实现"""
# def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
# if http_proxy == None:
# request_fun = openai.Completion.create
# else:
# request_fun = openai.Completion.acreate
# super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
# def request(self, prompts, **kwargs):
# prompts = self.__msg_handle__(prompts)
# kwargs['prompt'] = prompts
# super().request(**kwargs)
# self.ret_handle()
# def __msg_handle__(self, msgs):
# prompt = ''
# for msg in msgs:
# prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
# # for msg in msgs:
# # if msg['role'] == 'assistant':
# # prompt = prompt + "{}\n".format(msg['content'])
# # else:
# # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
# prompt = prompt + "assistant: "
# return prompt
# def get_message(self):
# return self.ret["choices"][0]["text"]
# def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest:
# """使用给定的模型名称创建模型请求对象"""
# if model_name in CHAT_COMPLETION_MODELS:
# model = ChatCompletionModel(model_name, user_name, http_proxy)
# elif model_name in COMPLETION_MODELS:
# model = CompletionModel(model_name, user_name, http_proxy)
# else :
# log = "找不到模型[{}],请检查配置文件".format(model_name)
# logging.error(log)
# raise IndexError(log)
# logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name))
# return model
def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return ChatCompletionRequest(model_name, messages, **args)

View File

@@ -44,6 +44,9 @@ __plugins__ = {}
__plugins_order__ = []
"""插件顺序"""
__enable_content_functions__ = True
"""是否启用内容函数"""
__callable_functions__ = []
"""供GPT调用的函数结构"""

View File

@@ -8,7 +8,10 @@ import logging
def wrapper_dict_from_runtime_context() -> dict:
"""从变量中包装settings.json的数据字典"""
settings = {
"order": []
"order": [],
"functions": {
"enable": host.__enable_content_functions__
}
}
for plugin_name in host.__plugins_order__:
@@ -22,6 +25,11 @@ def apply_settings(settings: dict):
if "order" in settings:
host.__plugins_order__ = settings["order"]
if "functions" in settings:
if "enable" in settings["functions"]:
host.__enable_content_functions__ = settings["functions"]["enable"]
# logging.debug("set content function enable: {}".format(host.__enable_content_functions__))
def dump_settings():
"""保存settings.json数据"""
@@ -78,6 +86,17 @@ def load_settings():
settings["order"].append(plugin_name)
settings_modified = True
if "functions" not in settings:
settings["functions"] = {
"enable": host.__enable_content_functions__
}
settings_modified = True
elif "enable" not in settings["functions"]:
settings["functions"]["enable"] = host.__enable_content_functions__
settings_modified = True
logging.info("已全局{}内容函数。".format("启用" if settings["functions"]["enable"] else "禁用"))
apply_settings(settings)
if settings_modified: