Merge pull request #19 from RockChinQ/default-prompt

feat: 支持预设指令
This commit is contained in:
Rock Chin
2022-12-12 17:22:23 +08:00
committed by GitHub
4 changed files with 49 additions and 36 deletions

View File

@@ -1,5 +1,4 @@
# 配置文件: 注释里标[必需]的参数必须修改, 其他参数根据需要修改, 但请勿删除
import logging
# [必需] Mirai的配置
@@ -25,12 +24,17 @@ openai_config = {
admin_qq = 0
# 敏感词过滤开关,以同样数量的*代替敏感词回复
# 开启后可能会降低机器人的回复速度
# 请在sensitive.json中添加敏感词
sensitive_word_filter = True
# 每个会话的预设信息,影响所有会话,无视指令重置
# 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令
# 例如: 如果我之后想获取帮助,请你说“输入!help获取帮助”
# 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh
default_prompt = ""
# OpenAI的completion API的参数
# 不了解的话请不要修改,具体请查看OpenAI的文档
# 具体请查看OpenAI的文档
completion_api_params = {
"model": "text-davinci-003",
"temperature": 0.6, # 数值越低得到的回答越理性,取值范围[0, 1]

View File

@@ -51,13 +51,19 @@ def dump_session(session_name: str):
del sessions[session_name]
# 从配置文件获取会话预设信息
def get_default_prompt():
return "You:{}\nBot:好的\n".format(config.default_prompt) if hasattr(config, 'default_prompt') and \
config.default_prompt != "" else ''
# 通用的OpenAI API交互session
# session内部保留了对话的上下文
# 收到用户消息后将上下文提交给OpenAI API生成回复
class Session:
name = ''
prompt = ''
prompt = get_default_prompt()
user_name = 'You'
bot_name = 'Bot'
@@ -100,12 +106,13 @@ class Session:
def append(self, text: str) -> str:
self.last_interact_timestamp = int(time.time())
max_length = config.prompt_submit_length if config.prompt_submit_length is not None else 1024
max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
# 向API请求补全
response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' +
text + '\n' + self.bot_name + ':',
7, max_length), self.user_name + ':')
max_rounds, max_length), self.user_name + ':')
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
# print(response)
@@ -154,7 +161,7 @@ class Session:
# 持久化session
def persistence(self):
if self.prompt == '':
if self.prompt == get_default_prompt():
return
db_inst = pkg.database.manager.get_inst()
@@ -169,14 +176,14 @@ class Session:
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True):
if self.prompt != '':
if self.prompt != get_default_prompt():
self.persistence()
if explicit:
pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp)
if expired:
pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp)
self.prompt = ''
self.prompt = get_default_prompt()
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False

View File

@@ -190,20 +190,21 @@ class QQBotManager:
else:
processing.append("person_{}".format(event.sender.id))
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
reply = self.process_message('person', event.sender.id, str(event.message_chain))
break
except FunctionTimedOut:
failed += 1
continue
try:
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
reply = self.process_message('person', event.sender.id, str(event.message_chain))
break
except FunctionTimedOut:
failed += 1
continue
if failed == self.retry:
reply = "[bot]err:请求超时"
processing.remove("person_{}".format(event.sender.id))
if failed == self.retry:
reply = "[bot]err:请求超时"
finally:
processing.remove("person_{}".format(event.sender.id))
if reply != '':
return await self.bot.send(event, reply)
@@ -225,20 +226,21 @@ class QQBotManager:
processing.append("group_{}".format(event.sender.id))
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
reply = self.process_message('group', event.group.id, str(event.message_chain).strip())
break
except FunctionTimedOut:
failed += 1
continue
try:
# 超时则重试,重试超过次数则放弃
failed = 0
for i in range(self.retry):
try:
reply = self.process_message('group', event.group.id, str(event.message_chain).strip())
break
except FunctionTimedOut:
failed += 1
continue
if failed == self.retry:
reply = "err:请求超时"
processing.remove("group_{}".format(event.sender.id))
if failed == self.retry:
reply = "err:请求超时"
finally:
processing.remove("group_{}".format(event.sender.id))
if reply != '':
return await self.bot.send(event, reply)

View File

@@ -33,7 +33,7 @@
"中华民国",
"pornhub",
"Pornhub",
"youporn",
"[Yy]ou[Pp]orn",
"porn",
"Porn",
"[Xx][Vv]ideos",