From 908cb1634ba1143808295bbbac4d9ced63697bd6 Mon Sep 17 00:00:00 2001 From: Rock Chin Date: Wed, 7 Dec 2022 22:50:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20session=E5=9F=BA=E6=9C=AC=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E5=8F=8A=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 2 ++ pkg/openai/session.py | 3 ++- tests/test_session_console.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 tests/test_session_console.py diff --git a/config-template.py b/config-template.py index c368ccbb..8d13d4dd 100644 --- a/config-template.py +++ b/config-template.py @@ -24,3 +24,5 @@ completion_api_params = { "frequency_penalty": 0.4, "presence_penalty": 0.3, } + +session_expire_time = 60 * 60 * 24 * 7 diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 5b5dc423..e01acdcf 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -35,6 +35,7 @@ class Session: # 向API请求补全 response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name+':') + # print(response) # 处理回复 res_test = response["choices"][0]["text"] res_ans = res_test @@ -45,7 +46,7 @@ class Session: del (res_ans_spt[0]) res_ans = '\n\n'.join(res_ans_spt) - self.prompt += "\n" + self.bot_name + ":{}".format(res_ans) + self.prompt += "{}".format(res_ans) + '\n' return res_ans def persistence(self): diff --git a/tests/test_session_console.py b/tests/test_session_console.py new file mode 100644 index 00000000..e0446ee8 --- /dev/null +++ b/tests/test_session_console.py @@ -0,0 +1,16 @@ +import config +import unittest +import pkg.openai.session +import pkg.openai.manager + + +class TestOpenAISession(unittest.TestCase): + def test_session_console(self): + interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) + + session = pkg.openai.session.Session('test') + print(session.append('你好')) + print("#{}#".format(session.prompt)) + + print(session.append('你叫什么名字')) + print("#{}#".format(session.prompt))