diff --git a/config-template.py b/config-template.py index 838a2b24..c62dc0e5 100644 --- a/config-template.py +++ b/config-template.py @@ -77,6 +77,12 @@ completion_api_params = { "presence_penalty": 1.0, } +# OpenAI的Image API的参数 +# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/images/create +image_api_params = { + "size": "256x256", +} + # 消息处理的超时时间,单位为秒 process_message_timeout = 15 diff --git a/main.py b/main.py index 988626fc..7352942c 100644 --- a/main.py +++ b/main.py @@ -59,7 +59,7 @@ def main(): database.initialize_database() - openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) + openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key']) # 加载所有未超时的session pkg.openai.session.load_sessions() diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index e753953f..e9279cfd 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -15,9 +15,12 @@ class OpenAIInteract: key_mgr = None - def __init__(self, api_key: str, api_params: dict): + default_image_api_params = { + "size": "256x256", + } + + def __init__(self, api_key: str): # self.api_key = api_key - self.api_params = api_params self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) @@ -28,12 +31,11 @@ class OpenAIInteract: # 请求OpenAI Completion def request_completion(self, prompt, stop): - logging.debug("请求OpenAI Completion, key:"+openai.api_key) response = openai.Completion.create( prompt=prompt, stop=stop, timeout=config.process_message_timeout, - **self.api_params + **config.completion_api_params ) switched = self.key_mgr.report_usage(prompt + response['choices'][0]['text']) if switched: @@ -41,6 +43,15 @@ class OpenAIInteract: return response + def request_image(self, prompt): + response = openai.Image.create( + prompt=prompt, + n=1, + **config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params + ) + + return response + def get_inst() -> OpenAIInteract: global inst diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 2d28f38a..9bc9e9e1 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -268,3 +268,6 @@ class Session: def list_history(self, capacity: int = 10, page: int = 0): return pkg.database.manager.get_inst().list_history(self.name, capacity, page, get_default_prompt()) + + def draw_image(self, prompt: str): + return pkg.openai.manager.get_inst().request_image(prompt) diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 886f4fd8..7b09b850 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -6,6 +6,8 @@ from func_timeout import func_set_timeout import logging import openai +from mirai import Image + import config import pkg.openai.session @@ -128,6 +130,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) -> reply_str += "\n当前使用:{}".format(using_key_name) reply = [reply_str] + + elif cmd == 'draw': + if len(params) == 0: + reply = ["[bot]err:请输入图片描述文字"] + else: + session = pkg.openai.session.get_session(session_name) + + res = session.draw_image(" ".join(params)) + + logging.debug("draw_image result:{}".format(res)) + reply = [Image(url=res['data'][0]['url'])] except Exception as e: mgr.notify_admin("{}指令执行失败:{}".format(session_name, e)) logging.exception(e)