From c57642bd4e42b7bafbd7968432654b4f435a4aae Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Fri, 3 Mar 2023 14:12:53 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0proxy=E4=BB=A3=E7=90=86?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 3 ++ pkg/openai/manager.py | 6 ++- pkg/openai/modelmgr.py | 96 +++++++++++++++++++++++++++--------------- 3 files changed, 71 insertions(+), 34 deletions(-) diff --git a/config-template.py b/config-template.py index fbebb710..c7a29dd3 100644 --- a/config-template.py +++ b/config-template.py @@ -20,6 +20,7 @@ mirai_http_api_config = { # [必需] OpenAI的配置 # api_key: OpenAI的API Key +# http_proxy: 请求OpenAI时使用的代理,None为不使用,https和socks5暂不能使用 # 若只有一个api-key,请直接修改以下内容中的"openai_api_key"为你的api-key # # 如准备了多个api-key,可以以字典的形式填写,程序会自动选择可用的api-key @@ -30,11 +31,13 @@ mirai_http_api_config = { # "key1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # "key2": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", # }, +# "http_proxy": "http://127.0.0.1:12345" # } openai_config = { "api_key": { "default": "openai_api_key" }, + "http_proxy": None } # [必需] 管理员QQ号,用于接收报错等通知及执行管理员级别指令 diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 3bd0c275..cfafbd84 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -36,7 +36,11 @@ class OpenAIInteract: config = pkg.utils.context.get_config() # 根据模型选择使用的接口 - ai: ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user') + ai: ModelRequest = create_openai_model_request( + config.completion_api_params['model'], + 'user', + config.openai_config["http_proxy"] + ) ai.request( prompts, **config.completion_api_params diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 0a68fac4..4bffc7b4 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -1,5 +1,5 @@ # 提供与模型交互的抽象接口 -import openai, logging +import openai, logging, threading, asyncio COMPLETION_MODELS = { 'text-davinci-003', @@ -26,48 +26,74 @@ IMAGE_MODELS = { class ModelRequest(): - """模型请求抽象类""" + """GPT父类""" can_chat = False + runtime:threading.Thread = None + ret = "" + proxy:str = None - def __init__(self, model_name, user_name, request_fun): + def __init__(self, model_name, user_name, request_fun, http_proxy:str = None): self.model_name = model_name self.user_name = user_name self.request_fun = request_fun + if http_proxy != None: + self.proxy = http_proxy + openai.proxy = self.proxy + + async def __a_request__(self, **kwargs): + self.ret = await self.request_fun(**kwargs) def request(self, **kwargs): - ret = self.request_fun(**kwargs) - self.ret = self.ret_handle(ret) - self.message = self.ret["choices"][0]["message"] + if self.proxy != None: #异步请求 + self.runtime = threading.Thread( + target=asyncio.run, + args=(self.__a_request__(**kwargs),) + ) + self.runtime.start() + else: #同步请求 + self.ret = self.request_fun(**kwargs) def __msg_handle__(self, msg): """将prompt dict转换成接口需要的格式""" return msg def ret_handle(self): + ''' + API消息返回处理函数 + 若重写该方法,应检查异步线程状态,或在需要检查处super该方法 + ''' + if self.runtime != None and isinstance(self.runtime, threading.Thread): + self.runtime.join() return - + def get_total_tokens(self): - return self.ret['usage']['total_tokens'] - + try: + return self.ret['usage']['total_tokens'] + except Exception: + return 0 + def get_message(self): return self.message - + def get_response(self): return self.ret - class ChatCompletionModel(ModelRequest): - """ChatCompletion接口实现""" + """ChatCompletion类模型""" Chat_role = ['system', 'user', 'assistant'] - def __init__(self, model_name, user_name): - request_fun = openai.ChatCompletion.create + def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): + if http_proxy == None: + request_fun = openai.ChatCompletion.create + else: + request_fun = openai.ChatCompletion.acreate self.can_chat = True - super().__init__(model_name, user_name, request_fun) + super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs) def request(self, prompts, **kwargs): - self.ret = self.request_fun(messages = self.__msg_handle__(prompts), **kwargs, user=self.user_name) + prompts = self.__msg_handle__(prompts) + kwargs['messages'] = prompts + super().request(**kwargs) self.ret_handle() - self.message = self.ret["choices"][0]["message"]['content'] def __msg_handle__(self, msgs): temp_msgs = [] @@ -76,20 +102,24 @@ class ChatCompletionModel(ModelRequest): temp_msgs.append(msg.copy()) return temp_msgs - def get_content(self): - return self.message - + def get_message(self): + return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗 + class CompletionModel(ModelRequest): - """Completion接口实现""" - def __init__(self, model_name, user_name): - request_fun = openai.Completion.create - super().__init__(model_name, user_name, request_fun) + """Completion类模型""" + def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): + if http_proxy == None: + request_fun = openai.Completion.create + else: + request_fun = openai.Completion.acreate + super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs) def request(self, prompts, **kwargs): - self.ret = self.request_fun(prompt = self.__msg_handle__(prompts), **kwargs) + prompts = self.__msg_handle__(prompts) + kwargs['prompt'] = prompts + super().request(**kwargs) self.ret_handle() - self.message = self.ret["choices"][0]["text"] def __msg_handle__(self, msgs): prompt = '' @@ -102,17 +132,17 @@ class CompletionModel(ModelRequest): # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) prompt = prompt + "assistant: " return prompt - - def get_text(self): - return self.message - -def create_openai_model_request(model_name: str, user_name: str = 'user') -> ModelRequest: + def get_message(self): + return self.ret["choices"][0]["text"] + + +def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest: """使用给定的模型名称创建模型请求对象""" if model_name in CHAT_COMPLETION_MODELS: - model = ChatCompletionModel(model_name, user_name) + model = ChatCompletionModel(model_name, user_name, http_proxy) elif model_name in COMPLETION_MODELS: - model = CompletionModel(model_name, user_name) + model = CompletionModel(model_name, user_name, http_proxy) else : log = "找不到模型[{}],请检查配置文件".format(model_name) logging.error(log) From c23d114094561c499be4f289a3bf3aa1d35074c6 Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Fri, 3 Mar 2023 15:20:42 +0800 Subject: [PATCH 2/2] =?UTF-8?q?proxy=E5=90=8E=E5=90=91=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E9=83=A8=E5=88=86=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/manager.py | 2 +- pkg/openai/modelmgr.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index cfafbd84..e5cef33d 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -39,7 +39,7 @@ class OpenAIInteract: ai: ModelRequest = create_openai_model_request( config.completion_api_params['model'], 'user', - config.openai_config["http_proxy"] + config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None ) ai.request( prompts, diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 4bffc7b4..64d427a1 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -45,8 +45,9 @@ class ModelRequest(): def request(self, **kwargs): if self.proxy != None: #异步请求 + loop = asyncio.new_event_loop() self.runtime = threading.Thread( - target=asyncio.run, + target=loop.run_until_complete, args=(self.__a_request__(**kwargs),) ) self.runtime.start()