From b1a2d21ee904c68167869701a4bc5a678d793eb4 Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Sun, 5 Mar 2023 13:52:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=BC=82=E5=B8=B8=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/modelmgr.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 64d427a1..2c8937c4 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -1,5 +1,6 @@ # 提供与模型交互的抽象接口 import openai, logging, threading, asyncio +import openai.error as aiE COMPLETION_MODELS = { 'text-davinci-003', @@ -29,22 +30,35 @@ class ModelRequest(): """GPT父类""" can_chat = False runtime:threading.Thread = None - ret = "" + ret = {} proxy:str = None + request_ready = True + error_info:str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues" - def __init__(self, model_name, user_name, request_fun, http_proxy:str = None): + def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None): self.model_name = model_name self.user_name = user_name self.request_fun = request_fun + self.time_out = time_out if http_proxy != None: self.proxy = http_proxy openai.proxy = self.proxy + self.request_ready = False async def __a_request__(self, **kwargs): - self.ret = await self.request_fun(**kwargs) + try: + self.ret:dict = await self.request_fun(**kwargs) + self.request_ready = True + except aiE.APIConnectionError as e: + self.error_info = "{}\n请检查网络连接或代理是否正常".format(e) + raise ConnectionError(self.error_info) + except Exception as e: + self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e) + raise Exception(self.error_info) def request(self, **kwargs): if self.proxy != None: #异步请求 + self.request_ready = False loop = asyncio.new_event_loop() self.runtime = threading.Thread( target=loop.run_until_complete, @@ -64,13 +78,15 @@ class ModelRequest(): 若重写该方法,应检查异步线程状态,或在需要检查处super该方法 ''' if self.runtime != None and isinstance(self.runtime, threading.Thread): - self.runtime.join() - return + self.runtime.join(self.time_out) + if self.request_ready: + return + raise Exception(self.error_info) def get_total_tokens(self): try: return self.ret['usage']['total_tokens'] - except Exception: + except: return 0 def get_message(self):