mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 09:16:04 +00:00
refactor: 修改引入风格
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
import json
|
||||
import logging
|
||||
|
||||
from .model import RequestBase
|
||||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
|
||||
from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError
|
||||
from .model import RequestBase
|
||||
from .. import funcmgr
|
||||
|
||||
|
||||
class ChatCompletionRequest(RequestBase):
|
||||
@@ -81,7 +81,7 @@ class ChatCompletionRequest(RequestBase):
|
||||
"messages": self.messages,
|
||||
}
|
||||
|
||||
funcs = get_func_schema_list()
|
||||
funcs = funcmgr.get_func_schema_list()
|
||||
|
||||
if len(funcs) > 0:
|
||||
args['functions'] = funcs
|
||||
@@ -171,7 +171,7 @@ class ChatCompletionRequest(RequestBase):
|
||||
# 若不是json格式的异常处理
|
||||
except json.decoder.JSONDecodeError:
|
||||
# 获取函数的参数列表
|
||||
func_schema = get_func_schema(func_name)
|
||||
func_schema = funcmgr.get_func_schema(func_name)
|
||||
|
||||
arguments = {
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
|
||||
@@ -182,7 +182,7 @@ class ChatCompletionRequest(RequestBase):
|
||||
# 执行函数调用
|
||||
ret = ""
|
||||
try:
|
||||
ret = execute_function(func_name, arguments)
|
||||
ret = funcmgr.execute_function(func_name, arguments)
|
||||
|
||||
logging.info("函数执行完成。")
|
||||
except Exception as e:
|
||||
@@ -216,6 +216,5 @@ class ChatCompletionRequest(RequestBase):
|
||||
}
|
||||
}
|
||||
|
||||
except ContentFunctionNotFoundError:
|
||||
except funcmgr.ContentFunctionNotFoundError:
|
||||
raise Exception("没有找到函数: {}".format(func_name))
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import openai
|
||||
from openai.types import completion, completion_choice
|
||||
|
||||
from .model import RequestBase
|
||||
from . import model
|
||||
|
||||
|
||||
class CompletionRequest(RequestBase):
|
||||
class CompletionRequest(model.RequestBase):
|
||||
"""调用Completion接口的请求类。
|
||||
|
||||
调用方可以一直next completion直到finish_reason为stop。
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# 定义不同接口请求的模型
|
||||
import threading
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# 封装了function calling的一些支持函数
|
||||
import logging
|
||||
|
||||
|
||||
from pkg.plugin import host
|
||||
from ..plugin import host
|
||||
|
||||
|
||||
class ContentFunctionNotFoundError(Exception):
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
|
||||
class KeysManager:
|
||||
|
||||
@@ -2,12 +2,11 @@ import logging
|
||||
|
||||
import openai
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.utils.context
|
||||
import pkg.audit.gatherer
|
||||
from pkg.openai.modelmgr import select_request_cls
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from ..openai import keymgr
|
||||
from ..utils import context
|
||||
from ..audit import gatherer
|
||||
from ..openai import modelmgr
|
||||
from ..openai.api import model as api_model
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
@@ -16,9 +15,9 @@ class OpenAIInteract:
|
||||
将文字接口和图片接口封装供调用方使用
|
||||
"""
|
||||
|
||||
key_mgr: pkg.openai.keymgr.KeysManager = None
|
||||
key_mgr: keymgr.KeysManager = None
|
||||
|
||||
audit_mgr: pkg.audit.gatherer.DataGatherer = None
|
||||
audit_mgr: gatherer.DataGatherer = None
|
||||
|
||||
default_image_api_params = {
|
||||
"size": "256x256",
|
||||
@@ -28,8 +27,8 @@ class OpenAIInteract:
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
|
||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
|
||||
self.key_mgr = keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = gatherer.DataGatherer()
|
||||
|
||||
# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
|
||||
|
||||
@@ -37,22 +36,22 @@ class OpenAIInteract:
|
||||
api_key=self.key_mgr.get_using_key()
|
||||
)
|
||||
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
context.set_openai_manager(self)
|
||||
|
||||
def request_completion(self, messages: list):
|
||||
"""请求补全接口回复=
|
||||
"""
|
||||
# 选择接口请求类
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
|
||||
request: RequestBase
|
||||
request: api_model.RequestBase
|
||||
|
||||
model: str = config.completion_api_params['model']
|
||||
|
||||
cp_parmas = config.completion_api_params.copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
request = select_request_cls(self.client, model, messages, cp_parmas)
|
||||
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
@@ -74,7 +73,7 @@ class OpenAIInteract:
|
||||
Returns:
|
||||
dict: 响应
|
||||
"""
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
params = config.image_api_params
|
||||
|
||||
response = openai.Image.create(
|
||||
|
||||
@@ -8,9 +8,9 @@ Completion - text-davinci-003 等模型
|
||||
import tiktoken
|
||||
import openai
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from pkg.openai.api.completion import CompletionRequest
|
||||
from pkg.openai.api.chat_completion import ChatCompletionRequest
|
||||
from ..openai.api import model as api_model
|
||||
from ..openai.api import completion as api_completion
|
||||
from ..openai.api import chat_completion as api_chat_completion
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
"text-davinci-003", # legacy
|
||||
@@ -60,11 +60,11 @@ IMAGE_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
return ChatCompletionRequest(client, model_name, messages, **args)
|
||||
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
return CompletionRequest(client, model_name, messages, **args)
|
||||
return api_completion.CompletionRequest(client, model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
|
||||
|
||||
|
||||
@@ -8,15 +8,13 @@ import threading
|
||||
import time
|
||||
import json
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.openai.modelmgr
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
from ..openai import manager as openai_manager
|
||||
from ..openai import modelmgr as openai_modelmgr
|
||||
from ..database import manager as database_manager
|
||||
from ..utils import context as context
|
||||
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
from pkg.openai.modelmgr import count_tokens
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
@@ -38,7 +36,7 @@ def reset_session_prompt(session_name, prompt):
|
||||
f.write(prompt)
|
||||
f.close()
|
||||
# 生成新数据
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
prompt = [
|
||||
{
|
||||
'role': 'system',
|
||||
@@ -61,7 +59,7 @@ def load_sessions():
|
||||
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
@@ -172,7 +170,7 @@ class Session:
|
||||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||
return
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
|
||||
@@ -182,7 +180,7 @@ class Session:
|
||||
'session': self,
|
||||
'session_expire_time': config.session_expire_time
|
||||
}
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
|
||||
event = plugin_host.emit(plugin_models.SessionExpired, **args)
|
||||
if event.is_prevented_default():
|
||||
return
|
||||
|
||||
@@ -214,11 +212,11 @@ class Session:
|
||||
'default_prompt': self.default_prompt,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
if event.is_prevented_default():
|
||||
return None, None, None
|
||||
|
||||
config = pkg.utils.context.get_config()
|
||||
config = context.get_config()
|
||||
max_length = config.prompt_submit_length
|
||||
|
||||
local_default_prompt = self.default_prompt.copy()
|
||||
@@ -232,7 +230,7 @@ class Session:
|
||||
'text_message': text,
|
||||
}
|
||||
|
||||
event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
|
||||
if event.get_return_value('default_prompt') is not None:
|
||||
local_default_prompt = event.get_return_value('default_prompt')
|
||||
@@ -256,14 +254,14 @@ class Session:
|
||||
funcs = []
|
||||
|
||||
trace_func_calls = config.trace_function_calls
|
||||
botmgr = pkg.utils.context.get_qqbot_manager()
|
||||
botmgr = context.get_qqbot_manager()
|
||||
|
||||
session_name_spt: list[str] = self.name.split("_")
|
||||
|
||||
pending_res_text = ""
|
||||
|
||||
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
|
||||
for resp in pkg.utils.context.get_openai_manager().request_completion(prompts):
|
||||
for resp in context.get_openai_manager().request_completion(prompts):
|
||||
|
||||
if pending_res_text != "":
|
||||
botmgr.adapter.send_message(
|
||||
@@ -325,7 +323,6 @@ class Session:
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
# 向API请求补全
|
||||
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
# prompts,
|
||||
@@ -383,13 +380,13 @@ class Session:
|
||||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
|
||||
use_model = pkg.utils.context.get_config().completion_api_params['model']
|
||||
use_model = context.get_config().completion_api_params['model']
|
||||
|
||||
ptr = len(prompt) - 1
|
||||
|
||||
# 直接从后向前扫描拼接,不管是否是整回合
|
||||
while ptr >= 0:
|
||||
if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, prompt[ptr])
|
||||
@@ -410,14 +407,14 @@ class Session:
|
||||
|
||||
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return result_prompt, count_tokens(changable_prompts, use_model)
|
||||
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == self.get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
@@ -439,12 +436,12 @@ class Session:
|
||||
}
|
||||
|
||||
# 此事件不支持阻止默认行为
|
||||
_ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
|
||||
if not persist: # 不要求保持default prompt
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
@@ -461,11 +458,11 @@ class Session:
|
||||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -486,7 +483,7 @@ class Session:
|
||||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -506,13 +503,13 @@ class Session:
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
|
||||
return context.get_database_manager().list_history(self.name, capacity, page)
|
||||
|
||||
def delete_history(self, index: int) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_history(self.name, index)
|
||||
return context.get_database_manager().delete_history(self.name, index)
|
||||
|
||||
def delete_all_history(self) -> bool:
|
||||
return pkg.utils.context.get_database_manager().delete_all_history(self.name)
|
||||
return context.get_database_manager().delete_all_history(self.name)
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
return context.get_openai_manager().request_image(prompt)
|
||||
|
||||
Reference in New Issue
Block a user