diff --git a/README.md b/README.md index 8e4547d6..f1d3c770 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ python3 main.py - [@mikumifa](https://github.com/mikumifa) 本项目Docker部署仓库开发者 - [@dominoar](https://github.com/dominoar) 为本项目开发多种插件 - [@hissincn](https://github.com/hissincn) 本项目贡献者 +- [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者 以及其他所有为本项目提供支持的朋友们。 diff --git a/config-template.py b/config-template.py index f0c8e275..fbebb710 100644 --- a/config-template.py +++ b/config-template.py @@ -112,10 +112,23 @@ encourage_sponsor_at_start = True # 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快 prompt_submit_length = 1024 -# OpenAI的completion API的参数 +# OpenAI补全API的参数 +# 请在下方填写模型,程序自动选择接口 +# 现已支持的模型有: +# +# 'gpt-3.5-turbo' +# 'gpt-3.5-turbo-0301' +# 'text-davinci-003' +# 'text-davinci-002' +# 'code-davinci-002' +# 'code-cushman-001' +# 'text-curie-001' +# 'text-babbage-001' +# 'text-ada-001' +# # 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/completions/create completion_api_params = { - "model": "text-davinci-003", + "model": "gpt-3.5-turbo", "temperature": 0.9, # 数值越低得到的回答越理性,取值范围[0, 1] "max_tokens": 512, # 每次获取OpenAI接口响应的文字量上限, 不高于4096 "top_p": 1, # 生成的文本的文本与要求的符合度, 取值范围[0, 1] @@ -138,14 +151,6 @@ include_image_description = True # 消息处理的超时时间,单位为秒 process_message_timeout = 30 -# 会话对象名称,此配置与会话对象管理相关, -# 若不了解相关功能,无需修改此配置 -# 详细说明请查看:https://github.com/RockChinQ/QChatGPT/wiki/%E6%8A%80%E6%9C%AF%E4%BF%A1%E6%81%AF#%E4%BC%9A%E8%AF%9Dsession -# user_name: 管理员(主人)的名字 -# bot_name: 机器人的名字 -user_name = 'You' -bot_name = 'Bot' - # [暂未实现] 群内会话是否启用多对象名称 # 若不启用,群内会话的prompt只使用user_name和bot_name multi_subject = False diff --git a/pkg/database/manager.py b/pkg/database/manager.py index a51ab509..519893bb 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -206,7 +206,7 @@ class DatabaseManager: } # 列出与某个对象的所有对话session - def list_history(self, session_name: str, capacity: int, page: int, replace: str = ""): + def list_history(self, session_name: str, capacity: int, page: int): self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} @@ -227,7 +227,7 @@ class DatabaseManager: 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt if replace == "" else prompt.replace(replace, "") + 'prompt': prompt }) return sessions diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index c64d21da..3bd0c275 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -5,7 +5,7 @@ import openai import pkg.openai.keymgr import pkg.utils.context import pkg.audit.gatherer - +from pkg.openai.modelmgr import ModelRequest, create_openai_model_request # 为其他模块提供与OpenAI交互的接口 class OpenAIInteract: @@ -32,24 +32,27 @@ class OpenAIInteract: pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion - def request_completion(self, prompt, stop): + def request_completion(self, prompts): config = pkg.utils.context.get_config() - response = openai.Completion.create( - prompt=prompt, - stop=stop, + + # 根据模型选择使用的接口 + ai: ModelRequest = create_openai_model_request(config.completion_api_params['model'], 'user') + ai.request( + prompts, **config.completion_api_params ) + response = ai.get_response() logging.debug("OpenAI response: %s", response) if 'model' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], - response['usage']['total_tokens']) + ai.get_total_tokens()) elif 'engine' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'], response['usage']['total_tokens']) - return response + return ai.get_message() def request_image(self, prompt): diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index c106824e..0a68fac4 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -1,7 +1,19 @@ # 提供与模型交互的抽象接口 +import openai, logging COMPLETION_MODELS = { - 'text-davinci-003' + 'text-davinci-003', + 'text-davinci-002', + 'code-davinci-002', + 'code-cushman-001', + 'text-curie-001', + 'text-babbage-001', + 'text-ada-001', +} + +CHAT_COMPLETION_MODELS = { + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-0301', } EDIT_MODELS = { @@ -13,22 +25,97 @@ IMAGE_MODELS = { } -# ModelManager -# 由session包含 -class ModelMgr(object): +class ModelRequest(): + """模型请求抽象类""" + can_chat = False - using_completion_model = "" - using_edit_model = "" - using_image_model = "" + def __init__(self, model_name, user_name, request_fun): + self.model_name = model_name + self.user_name = user_name + self.request_fun = request_fun - def __init__(self): - pass + def request(self, **kwargs): + ret = self.request_fun(**kwargs) + self.ret = self.ret_handle(ret) + self.message = self.ret["choices"][0]["message"] - def get_using_completion_model(self): - return self.using_completion_model + def __msg_handle__(self, msg): + """将prompt dict转换成接口需要的格式""" + return msg + + def ret_handle(self): + return + + def get_total_tokens(self): + return self.ret['usage']['total_tokens'] + + def get_message(self): + return self.message + + def get_response(self): + return self.ret + - def get_using_edit_model(self): - return self.using_edit_model +class ChatCompletionModel(ModelRequest): + """ChatCompletion接口实现""" + Chat_role = ['system', 'user', 'assistant'] + def __init__(self, model_name, user_name): + request_fun = openai.ChatCompletion.create + self.can_chat = True + super().__init__(model_name, user_name, request_fun) - def get_using_image_model(self): - return self.using_image_model + def request(self, prompts, **kwargs): + self.ret = self.request_fun(messages = self.__msg_handle__(prompts), **kwargs, user=self.user_name) + self.ret_handle() + self.message = self.ret["choices"][0]["message"]['content'] + + def __msg_handle__(self, msgs): + temp_msgs = [] + # 把msgs拷贝进temp_msgs + for msg in msgs: + temp_msgs.append(msg.copy()) + return temp_msgs + + def get_content(self): + return self.message + + +class CompletionModel(ModelRequest): + """Completion接口实现""" + def __init__(self, model_name, user_name): + request_fun = openai.Completion.create + super().__init__(model_name, user_name, request_fun) + + def request(self, prompts, **kwargs): + self.ret = self.request_fun(prompt = self.__msg_handle__(prompts), **kwargs) + self.ret_handle() + self.message = self.ret["choices"][0]["text"] + + def __msg_handle__(self, msgs): + prompt = '' + for msg in msgs: + prompt = prompt + "{}: {}\n".format(msg['role'], msg['content']) + # for msg in msgs: + # if msg['role'] == 'assistant': + # prompt = prompt + "{}\n".format(msg['content']) + # else: + # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) + prompt = prompt + "assistant: " + return prompt + + def get_text(self): + return self.message + + +def create_openai_model_request(model_name: str, user_name: str = 'user') -> ModelRequest: + """使用给定的模型名称创建模型请求对象""" + if model_name in CHAT_COMPLETION_MODELS: + model = ChatCompletionModel(model_name, user_name) + elif model_name in COMPLETION_MODELS: + model = CompletionModel(model_name, user_name) + else : + log = "找不到模型[{}],请检查配置文件".format(model_name) + logging.error(log) + raise IndexError(log) + logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name)) + return model diff --git a/pkg/openai/session.py b/pkg/openai/session.py index c04abc4e..9bd9e72d 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -1,8 +1,10 @@ import logging import threading import time +import json import pkg.openai.manager +import pkg.openai.modelmgr import pkg.database.manager import pkg.utils.context @@ -17,6 +19,32 @@ class SessionOfflineStatus: ON_GOING = 'on_going' EXPLICITLY_CLOSED = 'explicitly_closed' +# 重置session.prompt +def reset_session_prompt(session_name, prompt): + # 备份原始数据 + bak_path = 'logs/{}-{}.bak'.format( + session_name, + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + ) + f = open(bak_path, 'w+') + f.write(prompt) + f.close() + # 生成新数据 + config = pkg.utils.context.get_config() + prompt = [ + { + 'role': 'system', + 'content': config.default_prompt['default'] + } + ] + # 警告 + logging.warning( + """ +用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误 +原始数据将备份在: +{}""".format(session_name, bak_path) + ) + return prompt # 从数据加载session def load_sessions(): @@ -33,7 +61,11 @@ def load_sessions(): temp_session.name = session_name temp_session.create_timestamp = session_data[session_name]['create_timestamp'] temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] - temp_session.prompt = session_data[session_name]['prompt'] + try: + temp_session.prompt = json.loads(session_data[session_name]['prompt']) + except Exception: + temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) + temp_session.persistence() sessions[session_name] = temp_session @@ -60,12 +92,7 @@ def dump_session(session_name: str): class Session: name = '' - prompt = "" - - import config - - user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' - bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' + prompt = [] create_timestamp = 0 @@ -99,11 +126,15 @@ class Session: else: current_default_prompt = dprompt.get_prompt(use_default) - user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' - bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' - - return (user_name + ":{}\n".format(current_default_prompt) + bot_name + ":好的\n") \ - if current_default_prompt != '' else '' + return [ + { + 'role': 'user', + 'content': current_default_prompt + },{ + 'role': 'assistant', + 'content': 'ok' + } + ] def __init__(self, name: str): self.name = name @@ -165,22 +196,16 @@ class Session: if event.is_prevented_default(): return None - # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 config = pkg.utils.context.get_config() - max_rounds = 1000 # 不再限制回合数 max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 # 向API请求补全 - response = pkg.utils.context.get_openai_manager().request_completion( - self.cut_out(self.prompt + self.user_name + ':' + - text + '\n' + self.bot_name + ':', - max_rounds, max_length), - self.user_name + ':') + message = pkg.utils.context.get_openai_manager().request_completion( + self.cut_out(text, max_length), + ) - self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' - # print(response) - # 处理回复 - res_test = response["choices"][0]["text"] + # 成功获取,处理回复 + res_test = message res_ans = res_test # 去除开头可能的提示 @@ -189,50 +214,56 @@ class Session: del (res_ans_spt[0]) res_ans = '\n\n'.join(res_ans_spt) - self.prompt += "{}".format(res_ans) + '\n' + # 将此次对话的双方内容加入到prompt中 + self.prompt.append({'role':'user', 'content':text}) + self.prompt.append({'role':'assistant', 'content':res_ans}) if self.just_switched_to_exist_session: self.just_switched_to_exist_session = False self.set_ongoing() - return res_ans + return res_ans if res_ans[0]!='\n' else res_ans[1:] # 删除上一回合并返回上一回合的问题 def undo(self) -> str: self.last_interact_timestamp = int(time.time()) # 删除上一回合 - to_delete = self.cut_out(self.prompt, 1, 1024) - - self.prompt = self.prompt.replace(to_delete, '') + if self.prompt[-1]['role'] != 'user': + res = self.prompt[-1]['content'] + self.prompt.remove(self.prompt[-2]) + else: + res = self.prompt[-2]['content'] + self.prompt.remove(self.prompt[-1]) # 返回上一回合的问题 - return to_delete.split(self.bot_name + ':')[0].split(self.user_name + ':')[1].strip() + return res - # 从尾部截取prompt里不多于max_rounds个回合,长度不大于max_tokens的字符串 - # 保证都是完整的对话 - def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str: - # 分隔出每个回合 - rounds_spt_by_user_name = prompt.split(self.user_name + ':') + # 构建对话体 + def cut_out(self, msg: str, max_tokens: int) -> list: + """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" + # 如果用户消息长度超过max_tokens,直接返回 + + temp_prompt = [ + { + 'role': 'user', + 'content': msg + } + ] - result = '' - - checked_rounds = 0 - # 从后往前遍历,加到result前面,检查result是否符合要求 - for i in range(len(rounds_spt_by_user_name) - 1, 0, -1): - result_temp = self.user_name + ':' + rounds_spt_by_user_name[i] + result - checked_rounds += 1 - - if checked_rounds > max_rounds: + token_count = len(msg) + # 倒序遍历prompt + for i in range(len(self.prompt) - 1, -1, -1): + if token_count >= max_tokens: break - if int((len(result_temp.encode('utf-8')) - len(result_temp)) / 2 + len(result_temp)) > max_tokens: - break + # 将prompt加到temp_prompt头部 + temp_prompt.insert(0, self.prompt[i]) + token_count += len(self.prompt[i]['content']) - result = result_temp + logging.debug('cut_out: {}'.format(str(temp_prompt))) - logging.debug('cut_out: {}'.format(result)) - return result + return temp_prompt # 持久化session def persistence(self): @@ -247,11 +278,11 @@ class Session: subject_number = int(name_spt[1]) db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - self.prompt) + json.dumps(self.prompt)) # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): - if not self.prompt.endswith(':好的\n'): + if self.prompt[-1]['role'] != "system": self.persistence() if explicit: # 触发插件事件 @@ -291,7 +322,11 @@ class Session: self.create_timestamp = last_one['create_timestamp'] self.last_interact_timestamp = last_one['last_interact_timestamp'] - self.prompt = last_one['prompt'] + try: + self.prompt = json.loads(last_one['prompt']) + except json.decoder.JSONDecodeError: + self.prompt = reset_session_prompt(self.name, last_one['prompt']) + self.persistence() self.just_switched_to_exist_session = True return self @@ -306,14 +341,17 @@ class Session: self.create_timestamp = next_one['create_timestamp'] self.last_interact_timestamp = next_one['last_interact_timestamp'] - self.prompt = next_one['prompt'] + try: + self.prompt = json.loads(next_one['prompt']) + except json.decoder.JSONDecodeError: + self.prompt = reset_session_prompt(self.name, next_one['prompt']) + self.persistence() self.just_switched_to_exist_session = True return self def list_history(self, capacity: int = 10, page: int = 0): - return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page, - self.get_default_prompt()) + return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page) def draw_image(self, prompt: str): return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index 7de7e5a7..b174d453 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -185,11 +185,7 @@ def process_command(session_name: str, text_message: str, mgr, config, else: datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] + reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)] elif cmd == 'next': result = pkg.openai.session.get_session(session_name).next_session() if result is None: @@ -197,13 +193,18 @@ def process_command(session_name: str, text_message: str, mgr, config, else: datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] + reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)] elif cmd == 'prompt': - reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt] + msgs = "" + session:list = pkg.openai.session.get_session(session_name).prompt + for msg in session: + if len(params) != 0 and params[0] in ['-all', '-a']: + msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) + elif len(msg['content']) > 30: + msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30]) + else: + msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content']) + reply = ["[bot]当前对话所有内容:\n{}".format(msgs)] elif cmd == 'list': pkg.openai.session.get_session(session_name).persistence() page = 0 @@ -223,10 +224,21 @@ def process_command(session_name: str, text_message: str, mgr, config, for i in range(len(results)): # 时间(使用create_timestamp转换) 序号 部分内容 datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) - reply_str += "#{} 创建:{} {}\n".format(i + page * 10, - datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - results[i]['prompt'][ - :min(20, len(results[i]['prompt']))]) + msg = "" + try: + msg = json.loads(results[i]['prompt']) + except json.decoder.JSONDecodeError: + msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt']) + # 持久化 + pkg.openai.session.get_session(session_name).persistence() + if len(msg) >= 2: + reply_str += "#{} 创建:{} {}\n".format(i + page * 10, + datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), + msg[1]['content']) + else: + reply_str += "#{} 创建:{} {}\n".format(i + page * 10, + datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), + "无内容") if results[i]['create_timestamp'] == pkg.openai.session.get_session( session_name).create_timestamp: current = i + page * 10 diff --git a/requirements.txt b/requirements.txt index 7d445758..838279ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests~=2.28.1 -openai~=0.26.5 +openai~=0.27.0 pip~=22.3.1 dulwich~=0.21.3 colorlog~=6.6.0 diff --git a/tests/compatibility_tests/models_and_interfaces.py b/tests/compatibility_tests/models_and_interfaces.py new file mode 100644 index 00000000..1ace18d4 --- /dev/null +++ b/tests/compatibility_tests/models_and_interfaces.py @@ -0,0 +1,46 @@ +import openai +import time + +# 测试completion api +models = [ + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-0301', + 'text-davinci-003', + 'text-davinci-002', + 'code-davinci-002', + 'code-cushman-001', + 'text-curie-001', + 'text-babbage-001', + 'text-ada-001', +] + +openai.api_key = "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" + +for model in models: + print('Testing model: ', model) + + # completion api + try: + response = openai.Completion.create( + model=model, + prompt="Say this is a test", + max_tokens=7, + temperature=0 + ) + print(' completion api: ', response['choices'][0]['text'].strip()) + except Exception as e: + print(' completion api err: ', e) + + # chat completion api + try: + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "user", "content": "Hello!"} + ] + ) + print(" chat api: ",completion.choices[0].message['content'].strip()) + except Exception as e: + print(' chat api err: ', e) + + time.sleep(60)