refactor: 修改引入风格

This commit is contained in:
RockChinQ
2023-11-13 21:59:23 +08:00
parent e3b280758c
commit 665de5dc43
47 changed files with 324 additions and 364 deletions

View File

@@ -8,15 +8,13 @@ import threading
import time
import json
import pkg.openai.manager
import pkg.openai.modelmgr
import pkg.database.manager
import pkg.utils.context
from ..openai import manager as openai_manager
from ..openai import modelmgr as openai_modelmgr
from ..database import manager as database_manager
from ..utils import context as context
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
from pkg.openai.modelmgr import count_tokens
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
# 运行时保存的所有session
sessions = {}
@@ -38,7 +36,7 @@ def reset_session_prompt(session_name, prompt):
f.write(prompt)
f.close()
# 生成新数据
config = pkg.utils.context.get_config()
config = context.get_config()
prompt = [
{
'role': 'system',
@@ -61,7 +59,7 @@ def load_sessions():
global sessions
db_inst = pkg.utils.context.get_database_manager()
db_inst = context.get_database_manager()
session_data = db_inst.load_valid_sessions()
@@ -172,7 +170,7 @@ class Session:
if self.create_timestamp != create_timestamp or self not in sessions.values():
return
config = pkg.utils.context.get_config()
config = context.get_config()
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
logging.info('session {} 已过期'.format(self.name))
@@ -182,7 +180,7 @@ class Session:
'session': self,
'session_expire_time': config.session_expire_time
}
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
event = plugin_host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default():
return
@@ -214,11 +212,11 @@ class Session:
'default_prompt': self.default_prompt,
}
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
if event.is_prevented_default():
return None, None, None
config = pkg.utils.context.get_config()
config = context.get_config()
max_length = config.prompt_submit_length
local_default_prompt = self.default_prompt.copy()
@@ -232,7 +230,7 @@ class Session:
'text_message': text,
}
event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args)
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
if event.get_return_value('default_prompt') is not None:
local_default_prompt = event.get_return_value('default_prompt')
@@ -256,14 +254,14 @@ class Session:
funcs = []
trace_func_calls = config.trace_function_calls
botmgr = pkg.utils.context.get_qqbot_manager()
botmgr = context.get_qqbot_manager()
session_name_spt: list[str] = self.name.split("_")
pending_res_text = ""
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
for resp in pkg.utils.context.get_openai_manager().request_completion(prompts):
for resp in context.get_openai_manager().request_completion(prompts):
if pending_res_text != "":
botmgr.adapter.send_message(
@@ -325,7 +323,6 @@ class Session:
)
pass
# 向API请求补全
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
# prompts,
@@ -383,13 +380,13 @@ class Session:
# 包装目前的对话回合内容
changable_prompts = []
use_model = pkg.utils.context.get_config().completion_api_params['model']
use_model = context.get_config().completion_api_params['model']
ptr = len(prompt) - 1
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break
changable_prompts.insert(0, prompt[ptr])
@@ -410,14 +407,14 @@ class Session:
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
return result_prompt, count_tokens(changable_prompts, use_model)
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
# 持久化session
def persistence(self):
if self.prompt == self.get_default_prompt():
return
db_inst = pkg.utils.context.get_database_manager()
db_inst = context.get_database_manager()
name_spt = self.name.split('_')
@@ -439,12 +436,12 @@ class Session:
}
# 此事件不支持阻止默认行为
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
if not persist: # 不要求保持default prompt
self.default_prompt = self.get_default_prompt(use_prompt)
@@ -461,11 +458,11 @@ class Session:
# 将本session的数据库状态设置为on_going
def set_ongoing(self):
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
# 切换到上一个session
def last_session(self):
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
if last_one is None:
return None
else:
@@ -486,7 +483,7 @@ class Session:
# 切换到下一个session
def next_session(self):
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
if next_one is None:
return None
else:
@@ -506,13 +503,13 @@ class Session:
return self
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
return context.get_database_manager().list_history(self.name, capacity, page)
def delete_history(self, index: int) -> bool:
return pkg.utils.context.get_database_manager().delete_history(self.name, index)
return context.get_database_manager().delete_history(self.name, index)
def delete_all_history(self) -> bool:
return pkg.utils.context.get_database_manager().delete_all_history(self.name)
return context.get_database_manager().delete_all_history(self.name)
def draw_image(self, prompt: str):
return pkg.utils.context.get_openai_manager().request_image(prompt)
return context.get_openai_manager().request_image(prompt)