mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat(contentPlugin): 完成基本的内容函数调用功能
This commit is contained in:
0
pkg/openai/api/__init__.py
Normal file
0
pkg/openai/api/__init__.py
Normal file
182
pkg/openai/api/chat_completion.py
Normal file
182
pkg/openai/api/chat_completion.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import openai
|
||||
import json
|
||||
|
||||
from .model import RequestBase
|
||||
|
||||
from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError
|
||||
|
||||
|
||||
class ChatCompletionRequest(RequestBase):
|
||||
"""调用ChatCompletion接口的请求类。
|
||||
|
||||
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
|
||||
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
|
||||
"""
|
||||
model: str
|
||||
messages: list[dict[str, str]]
|
||||
kwargs: dict
|
||||
|
||||
stopped: bool = False
|
||||
|
||||
pending_func_call: dict = None
|
||||
|
||||
pending_msg: str
|
||||
|
||||
def flush_pending_msg(self):
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content=self.pending_msg
|
||||
)
|
||||
self.pending_msg = ""
|
||||
|
||||
def append_message(self, role: str, content: str, name: str=None):
|
||||
msg = {
|
||||
"role": role,
|
||||
"content": content
|
||||
}
|
||||
|
||||
if name is not None:
|
||||
msg['name'] = name
|
||||
|
||||
self.messages.append(msg)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.messages = messages.copy()
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = openai.ChatCompletion.acreate
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
self.stopped = False
|
||||
|
||||
self.pending_msg = ""
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> dict:
|
||||
if self.stopped:
|
||||
raise StopIteration()
|
||||
|
||||
if self.pending_func_call is None: # 没有待处理的函数调用请求
|
||||
|
||||
resp = self._req(
|
||||
model=self.model,
|
||||
messages=self.messages,
|
||||
functions=get_func_schema_list(),
|
||||
**self.kwargs
|
||||
)
|
||||
|
||||
choice0 = resp["choices"][0]
|
||||
|
||||
# 如果不是函数调用,且finish_reason为stop,则停止迭代
|
||||
if 'function_call' not in choice0['message'] and choice0["finish_reason"] == "stop":
|
||||
self.stopped = True
|
||||
|
||||
if 'function_call' in choice0['message']:
|
||||
self.pending_func_call = choice0['message']['function_call']
|
||||
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content="function call: "+json.dumps(self.pending_func_call)
|
||||
)
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "function_call",
|
||||
"content": None,
|
||||
"function_call": choice0['message']['function_call']
|
||||
},
|
||||
"finish_reason": "function_call"
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
}
|
||||
else:
|
||||
|
||||
# self.pending_msg += choice0['message']['content']
|
||||
# 普通回复一定处于最后方,故不用再追加进内部messages
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0['message']['content']
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
}
|
||||
else: # 处理函数调用请求
|
||||
|
||||
cp_pending_func_call = self.pending_func_call.copy()
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
func_name = cp_pending_func_call['name']
|
||||
arguments = {}
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
arguments = json.loads(cp_pending_func_call['arguments'])
|
||||
# 若不是json格式的异常处理
|
||||
except json.decoder.JSONDecodeError:
|
||||
# 获取函数的参数列表
|
||||
func_schema = get_func_schema(func_name)
|
||||
|
||||
arguments = {
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call['arguments']
|
||||
}
|
||||
|
||||
# 执行函数调用
|
||||
ret = execute_function(func_name, arguments)
|
||||
|
||||
self.append_message(
|
||||
role="function",
|
||||
content=json.dumps(ret),
|
||||
name=func_name
|
||||
)
|
||||
|
||||
return {
|
||||
"id": -1,
|
||||
"choices": [
|
||||
{
|
||||
"index": -1,
|
||||
"message": {
|
||||
"role": "function",
|
||||
"type": "function_return",
|
||||
"function_name": func_name,
|
||||
"content": json.dumps(ret)
|
||||
},
|
||||
"finish_reason": "function_return"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
except ContentFunctionNotFoundError:
|
||||
raise Exception("没有找到函数: {}".format(func_name))
|
||||
|
||||
111
pkg/openai/api/completion.py
Normal file
111
pkg/openai/api/completion.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import openai
|
||||
|
||||
from .model import RequestBase
|
||||
|
||||
|
||||
class CompletionRequest(RequestBase):
|
||||
"""调用Completion接口的请求类。
|
||||
|
||||
调用方可以一直next completion直到finish_reason为stop。
|
||||
"""
|
||||
|
||||
model: str
|
||||
prompt: str
|
||||
kwargs: dict
|
||||
|
||||
stopped: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.prompt = ""
|
||||
|
||||
for message in messages:
|
||||
self.prompt += message["role"] + ": " + message["content"] + "\n"
|
||||
|
||||
self.prompt += "assistant: "
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = openai.Completion.acreate
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> dict:
|
||||
"""调用Completion接口,返回生成的文本
|
||||
|
||||
{
|
||||
"id": "id",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": "message"
|
||||
},
|
||||
"finish_reason": "reason"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
if self.stopped:
|
||||
raise StopIteration()
|
||||
|
||||
resp = self._req(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
**self.kwargs
|
||||
)
|
||||
|
||||
if resp["choices"][0]["finish_reason"] == "stop":
|
||||
self.stopped = True
|
||||
|
||||
choice0 = resp["choices"][0]
|
||||
|
||||
self.prompt += choice0["text"]
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0["text"]
|
||||
},
|
||||
"finish_reason": choice0["finish_reason"]
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
openai.api_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
for resp in CompletionRequest(
|
||||
model="text-davinci-003",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
]
|
||||
):
|
||||
print(resp)
|
||||
if resp["choices"][0]["finish_reason"] == "stop":
|
||||
break
|
||||
42
pkg/openai/api/model.py
Normal file
42
pkg/openai/api/model.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# 定义不同接口请求的模型
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
class RequestBase:
|
||||
|
||||
req_func: callable
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _req(self, **kwargs):
|
||||
"""处理代理问题"""
|
||||
|
||||
ret: dict = {}
|
||||
|
||||
async def awrapper(**kwargs):
|
||||
nonlocal ret
|
||||
|
||||
ret = await self.req_func(**kwargs)
|
||||
return ret
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
thr = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(awrapper(**kwargs),)
|
||||
)
|
||||
|
||||
thr.start()
|
||||
thr.join()
|
||||
|
||||
return ret
|
||||
|
||||
def __iter__(self):
|
||||
raise self
|
||||
|
||||
def __next__(self):
|
||||
raise NotImplementedError
|
||||
37
pkg/openai/funcmgr.py
Normal file
37
pkg/openai/funcmgr.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# 封装了function calling的一些支持函数
|
||||
import logging
|
||||
|
||||
|
||||
from pkg.plugin.host import __callable_functions__, __function_inst_map__
|
||||
|
||||
|
||||
class ContentFunctionNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_func_schema_list() -> list:
|
||||
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
|
||||
|
||||
schemas = __callable_functions__
|
||||
|
||||
return schemas
|
||||
|
||||
def get_func(name: str) -> callable:
|
||||
if name not in __function_inst_map__:
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
return __function_inst_map__[name]
|
||||
|
||||
def get_func_schema(name: str) -> dict:
|
||||
for func in __callable_functions__:
|
||||
if func['name'] == name:
|
||||
return func
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
def execute_function(name: str, kwargs: dict) -> any:
|
||||
"""执行函数调用"""
|
||||
|
||||
logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs))
|
||||
|
||||
func = get_func(name)
|
||||
return func(**kwargs)
|
||||
@@ -5,7 +5,9 @@ import openai
|
||||
import pkg.openai.keymgr
|
||||
import pkg.utils.context
|
||||
import pkg.audit.gatherer
|
||||
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
|
||||
from pkg.openai.modelmgr import select_request_cls
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
@@ -33,45 +35,58 @@ class OpenAIInteract:
|
||||
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompts) -> tuple[str, int]:
|
||||
"""请求补全接口回复
|
||||
|
||||
Parameters:
|
||||
prompts (str): 提示语
|
||||
|
||||
Returns:
|
||||
str: 回复
|
||||
def request_completion(self, messages: list):
|
||||
"""请求补全接口回复=
|
||||
"""
|
||||
|
||||
# 选择接口请求类
|
||||
config = pkg.utils.context.get_config()
|
||||
|
||||
# 根据模型选择使用的接口
|
||||
ai: ModelRequest = create_openai_model_request(
|
||||
config.completion_api_params['model'],
|
||||
'user',
|
||||
config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None
|
||||
)
|
||||
ai.request(
|
||||
prompts,
|
||||
**config.completion_api_params
|
||||
)
|
||||
response = ai.get_response()
|
||||
request: RequestBase
|
||||
|
||||
logging.debug("OpenAI response: %s", response)
|
||||
model: str = config.completion_api_params['model']
|
||||
|
||||
# 记录使用量
|
||||
current_round_token = 0
|
||||
if 'model' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||
ai.get_total_tokens())
|
||||
current_round_token = ai.get_total_tokens()
|
||||
elif 'engine' in config.completion_api_params:
|
||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
||||
response['usage']['total_tokens'])
|
||||
current_round_token = response['usage']['total_tokens']
|
||||
cp_parmas = config.completion_api_params.copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
return ai.get_message(), current_round_token
|
||||
request = select_request_cls(model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
yield resp
|
||||
|
||||
# 请求OpenAI Completion
|
||||
# def request_completion(self, prompts):
|
||||
# """请求补全接口回复
|
||||
# """
|
||||
|
||||
# config = pkg.utils.context.get_config()
|
||||
|
||||
# # 根据模型选择使用的接口
|
||||
# ai: ModelRequest = create_openai_model_request(
|
||||
# config.completion_api_params['model'],
|
||||
# 'user',
|
||||
# config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None
|
||||
# )
|
||||
# ai.request(
|
||||
# prompts,
|
||||
# **config.completion_api_params
|
||||
# )
|
||||
# response = ai.get_response()
|
||||
|
||||
# logging.debug("OpenAI response: %s", response)
|
||||
|
||||
# # 记录使用量
|
||||
# current_round_token = 0
|
||||
# if 'model' in config.completion_api_params:
|
||||
# self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||
# ai.get_total_tokens())
|
||||
# current_round_token = ai.get_total_tokens()
|
||||
# elif 'engine' in config.completion_api_params:
|
||||
# self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
||||
# response['usage']['total_tokens'])
|
||||
# current_round_token = response['usage']['total_tokens']
|
||||
|
||||
# return ai.get_message(), current_round_token
|
||||
|
||||
def request_image(self, prompt) -> dict:
|
||||
"""请求图片接口回复
|
||||
|
||||
@@ -8,6 +8,10 @@ Completion - text-davinci-003 等模型
|
||||
import openai, logging, threading, asyncio
|
||||
import openai.error as aiE
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from pkg.openai.api.completion import CompletionRequest
|
||||
from pkg.openai.api.chat_completion import ChatCompletionRequest
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
'text-davinci-003',
|
||||
'text-davinci-002',
|
||||
@@ -39,153 +43,160 @@ IMAGE_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
class ModelRequest:
|
||||
"""模型接口请求父类"""
|
||||
# class ModelRequest:
|
||||
# """模型接口请求父类"""
|
||||
|
||||
can_chat = False
|
||||
runtime: threading.Thread = None
|
||||
ret = {}
|
||||
proxy: str = None
|
||||
request_ready = True
|
||||
error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues"
|
||||
# can_chat = False
|
||||
# runtime: threading.Thread = None
|
||||
# ret = {}
|
||||
# proxy: str = None
|
||||
# request_ready = True
|
||||
# error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues"
|
||||
|
||||
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
|
||||
self.model_name = model_name
|
||||
self.user_name = user_name
|
||||
self.request_fun = request_fun
|
||||
self.time_out = time_out
|
||||
if http_proxy != None:
|
||||
self.proxy = http_proxy
|
||||
openai.proxy = self.proxy
|
||||
self.request_ready = False
|
||||
# def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
|
||||
# self.model_name = model_name
|
||||
# self.user_name = user_name
|
||||
# self.request_fun = request_fun
|
||||
# self.time_out = time_out
|
||||
# if http_proxy != None:
|
||||
# self.proxy = http_proxy
|
||||
# openai.proxy = self.proxy
|
||||
# self.request_ready = False
|
||||
|
||||
async def __a_request__(self, **kwargs):
|
||||
"""异步请求"""
|
||||
# async def __a_request__(self, **kwargs):
|
||||
# """异步请求"""
|
||||
|
||||
try:
|
||||
self.ret: dict = await self.request_fun(**kwargs)
|
||||
self.request_ready = True
|
||||
except aiE.APIConnectionError as e:
|
||||
self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
|
||||
raise ConnectionError(self.error_info)
|
||||
except ValueError as e:
|
||||
self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
|
||||
except Exception as e:
|
||||
self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
|
||||
raise type(e)(self.error_info)
|
||||
# try:
|
||||
# self.ret: dict = await self.request_fun(**kwargs)
|
||||
# self.request_ready = True
|
||||
# except aiE.APIConnectionError as e:
|
||||
# self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
|
||||
# raise ConnectionError(self.error_info)
|
||||
# except ValueError as e:
|
||||
# self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
|
||||
# except Exception as e:
|
||||
# self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
|
||||
# raise type(e)(self.error_info)
|
||||
|
||||
def request(self, **kwargs):
|
||||
"""向接口发起请求"""
|
||||
# def request(self, **kwargs):
|
||||
# """向接口发起请求"""
|
||||
|
||||
if self.proxy != None: #异步请求
|
||||
self.request_ready = False
|
||||
loop = asyncio.new_event_loop()
|
||||
self.runtime = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(self.__a_request__(**kwargs),)
|
||||
)
|
||||
self.runtime.start()
|
||||
else: #同步请求
|
||||
self.ret = self.request_fun(**kwargs)
|
||||
# if self.proxy != None: #异步请求
|
||||
# self.request_ready = False
|
||||
# loop = asyncio.new_event_loop()
|
||||
# self.runtime = threading.Thread(
|
||||
# target=loop.run_until_complete,
|
||||
# args=(self.__a_request__(**kwargs),)
|
||||
# )
|
||||
# self.runtime.start()
|
||||
# else: #同步请求
|
||||
# self.ret = self.request_fun(**kwargs)
|
||||
|
||||
def __msg_handle__(self, msg):
|
||||
"""将prompt dict转换成接口需要的格式"""
|
||||
return msg
|
||||
# def __msg_handle__(self, msg):
|
||||
# """将prompt dict转换成接口需要的格式"""
|
||||
# return msg
|
||||
|
||||
def ret_handle(self):
|
||||
'''
|
||||
API消息返回处理函数
|
||||
若重写该方法,应检查异步线程状态,或在需要检查处super该方法
|
||||
'''
|
||||
if self.runtime != None and isinstance(self.runtime, threading.Thread):
|
||||
self.runtime.join(self.time_out)
|
||||
if self.request_ready:
|
||||
return
|
||||
raise Exception(self.error_info)
|
||||
# def ret_handle(self):
|
||||
# '''
|
||||
# API消息返回处理函数
|
||||
# 若重写该方法,应检查异步线程状态,或在需要检查处super该方法
|
||||
# '''
|
||||
# if self.runtime != None and isinstance(self.runtime, threading.Thread):
|
||||
# self.runtime.join(self.time_out)
|
||||
# if self.request_ready:
|
||||
# return
|
||||
# raise Exception(self.error_info)
|
||||
|
||||
def get_total_tokens(self):
|
||||
try:
|
||||
return self.ret['usage']['total_tokens']
|
||||
except:
|
||||
return 0
|
||||
# def get_total_tokens(self):
|
||||
# try:
|
||||
# return self.ret['usage']['total_tokens']
|
||||
# except:
|
||||
# return 0
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
# def get_message(self):
|
||||
# return self.message
|
||||
|
||||
def get_response(self):
|
||||
return self.ret
|
||||
# def get_response(self):
|
||||
# return self.ret
|
||||
|
||||
|
||||
class ChatCompletionModel(ModelRequest):
|
||||
"""ChatCompletion接口的请求实现"""
|
||||
# class ChatCompletionModel(ModelRequest):
|
||||
# """ChatCompletion接口的请求实现"""
|
||||
|
||||
Chat_role = ['system', 'user', 'assistant']
|
||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
if http_proxy == None:
|
||||
request_fun = openai.ChatCompletion.create
|
||||
else:
|
||||
request_fun = openai.ChatCompletion.acreate
|
||||
self.can_chat = True
|
||||
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
# Chat_role = ['system', 'user', 'assistant']
|
||||
# def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
# if http_proxy == None:
|
||||
# request_fun = openai.ChatCompletion.create
|
||||
# else:
|
||||
# request_fun = openai.ChatCompletion.acreate
|
||||
# self.can_chat = True
|
||||
# super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
|
||||
def request(self, prompts, **kwargs):
|
||||
prompts = self.__msg_handle__(prompts)
|
||||
kwargs['messages'] = prompts
|
||||
super().request(**kwargs)
|
||||
self.ret_handle()
|
||||
# def request(self, prompts, **kwargs):
|
||||
# prompts = self.__msg_handle__(prompts)
|
||||
# kwargs['messages'] = prompts
|
||||
# super().request(**kwargs)
|
||||
# self.ret_handle()
|
||||
|
||||
def __msg_handle__(self, msgs):
|
||||
temp_msgs = []
|
||||
# 把msgs拷贝进temp_msgs
|
||||
for msg in msgs:
|
||||
temp_msgs.append(msg.copy())
|
||||
return temp_msgs
|
||||
# def __msg_handle__(self, msgs):
|
||||
# temp_msgs = []
|
||||
# # 把msgs拷贝进temp_msgs
|
||||
# for msg in msgs:
|
||||
# temp_msgs.append(msg.copy())
|
||||
# return temp_msgs
|
||||
|
||||
def get_message(self):
|
||||
return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗
|
||||
# def get_message(self):
|
||||
# return self.ret["choices"][0]["message"]['content'] #需要时直接加载加快请求速度,降低内存消耗
|
||||
|
||||
|
||||
class CompletionModel(ModelRequest):
|
||||
"""Completion接口的请求实现"""
|
||||
# class CompletionModel(ModelRequest):
|
||||
# """Completion接口的请求实现"""
|
||||
|
||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
if http_proxy == None:
|
||||
request_fun = openai.Completion.create
|
||||
else:
|
||||
request_fun = openai.Completion.acreate
|
||||
super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
# def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||
# if http_proxy == None:
|
||||
# request_fun = openai.Completion.create
|
||||
# else:
|
||||
# request_fun = openai.Completion.acreate
|
||||
# super().__init__(model_name, user_name, request_fun, http_proxy, **kwargs)
|
||||
|
||||
def request(self, prompts, **kwargs):
|
||||
prompts = self.__msg_handle__(prompts)
|
||||
kwargs['prompt'] = prompts
|
||||
super().request(**kwargs)
|
||||
self.ret_handle()
|
||||
# def request(self, prompts, **kwargs):
|
||||
# prompts = self.__msg_handle__(prompts)
|
||||
# kwargs['prompt'] = prompts
|
||||
# super().request(**kwargs)
|
||||
# self.ret_handle()
|
||||
|
||||
def __msg_handle__(self, msgs):
|
||||
prompt = ''
|
||||
for msg in msgs:
|
||||
prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
|
||||
# for msg in msgs:
|
||||
# if msg['role'] == 'assistant':
|
||||
# prompt = prompt + "{}\n".format(msg['content'])
|
||||
# else:
|
||||
# prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
|
||||
prompt = prompt + "assistant: "
|
||||
return prompt
|
||||
# def __msg_handle__(self, msgs):
|
||||
# prompt = ''
|
||||
# for msg in msgs:
|
||||
# prompt = prompt + "{}: {}\n".format(msg['role'], msg['content'])
|
||||
# # for msg in msgs:
|
||||
# # if msg['role'] == 'assistant':
|
||||
# # prompt = prompt + "{}\n".format(msg['content'])
|
||||
# # else:
|
||||
# # prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content'])
|
||||
# prompt = prompt + "assistant: "
|
||||
# return prompt
|
||||
|
||||
def get_message(self):
|
||||
return self.ret["choices"][0]["text"]
|
||||
# def get_message(self):
|
||||
# return self.ret["choices"][0]["text"]
|
||||
|
||||
|
||||
def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest:
|
||||
"""使用给定的模型名称创建模型请求对象"""
|
||||
# def create_openai_model_request(model_name: str, user_name: str = 'user', http_proxy:str = None) -> ModelRequest:
|
||||
# """使用给定的模型名称创建模型请求对象"""
|
||||
# if model_name in CHAT_COMPLETION_MODELS:
|
||||
# model = ChatCompletionModel(model_name, user_name, http_proxy)
|
||||
# elif model_name in COMPLETION_MODELS:
|
||||
# model = CompletionModel(model_name, user_name, http_proxy)
|
||||
# else :
|
||||
# log = "找不到模型[{}],请检查配置文件".format(model_name)
|
||||
# logging.error(log)
|
||||
# raise IndexError(log)
|
||||
# logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name))
|
||||
# return model
|
||||
|
||||
def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
model = ChatCompletionModel(model_name, user_name, http_proxy)
|
||||
return ChatCompletionRequest(model_name, messages, **args)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
model = CompletionModel(model_name, user_name, http_proxy)
|
||||
else :
|
||||
log = "找不到模型[{}],请检查配置文件".format(model_name)
|
||||
logging.error(log)
|
||||
raise IndexError(log)
|
||||
logging.debug("使用接口[{}]创建模型请求[{}]".format(model.__class__.__name__, model_name))
|
||||
return model
|
||||
return CompletionRequest(model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
79
pkg/openai/sess.py
Normal file
79
pkg/openai/sess.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
|
||||
|
||||
sessions = {}
|
||||
|
||||
|
||||
class SessionOfflineStatus:
|
||||
ON_GOING = "on_going"
|
||||
EXPLICITLY_CLOSED = "explicitly_closed"
|
||||
|
||||
|
||||
def reset_session_prompt(session_name, prompt):
|
||||
pass
|
||||
|
||||
|
||||
def load_sessions():
|
||||
pass
|
||||
|
||||
|
||||
def get_session(session_name: str) -> 'Session':
|
||||
pass
|
||||
|
||||
|
||||
def dump_session(session_name: str):
|
||||
pass
|
||||
|
||||
|
||||
class Session:
|
||||
name: str = ''
|
||||
|
||||
default_prompt: list = []
|
||||
"""会话系统提示语"""
|
||||
|
||||
messages: list = []
|
||||
"""保存消息历史记录"""
|
||||
|
||||
token_counts: list = []
|
||||
"""记录每回合的token数量"""
|
||||
|
||||
create_ts: int = 0
|
||||
"""会话创建时间戳"""
|
||||
|
||||
last_active_ts: int = 0
|
||||
"""会话最后活跃时间戳"""
|
||||
|
||||
just_switched_to_exist_session: bool = False
|
||||
|
||||
response_lock = None
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.default_prompt = self.get_runtime_default_prompt()
|
||||
logging.debug("prompt is: {}".format(self.default_prompt))
|
||||
self.messages = []
|
||||
self.token_counts = []
|
||||
self.create_ts = int(time.time())
|
||||
self.last_active_ts = int(time.time())
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
|
||||
self.schedule()
|
||||
|
||||
def get_runtime_default_prompt(self, use_default: str = None) -> list:
|
||||
"""从提示词管理器中获取所需提示词"""
|
||||
import pkg.openai.dprompt as dprompt
|
||||
|
||||
if use_default is None:
|
||||
use_default = dprompt.mode_inst().get_using_name()
|
||||
|
||||
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
|
||||
return current_default_prompt
|
||||
|
||||
def schedule(self):
|
||||
"""定时会话过期检查任务"""
|
||||
|
||||
def expire_check_timer_loop(self):
|
||||
"""会话过期检查任务"""
|
||||
@@ -222,22 +222,67 @@ class Session:
|
||||
for token_count in counts:
|
||||
total_token_before_query += token_count
|
||||
|
||||
res_text = ""
|
||||
|
||||
pending_msgs = []
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
for resp in pkg.utils.context.get_openai_manager().request_completion(prompts):
|
||||
if resp['choices'][0]['message']['type'] == 'text': # 普通回复
|
||||
res_text += resp['choices'][0]['message']['content']
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
|
||||
pending_msgs.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": resp['choices'][0]['message']['content']
|
||||
}
|
||||
)
|
||||
|
||||
elif resp['choices'][0]['message']['type'] == 'function_call':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call'])
|
||||
# }
|
||||
# )
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
elif resp['choices'][0]['message']['type'] == 'function_return':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "function",
|
||||
# "name": resp['choices'][0]['message']['function_name'],
|
||||
# "content": json.dumps(resp['choices'][0]['message']['content'])
|
||||
# }
|
||||
# )
|
||||
|
||||
# total_tokens += resp['usage']['total_tokens']
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# 向API请求补全
|
||||
message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
prompts,
|
||||
)
|
||||
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
# prompts,
|
||||
# )
|
||||
|
||||
# 成功获取,处理回复
|
||||
res_test = message
|
||||
res_ans = res_test.strip()
|
||||
# res_test = message
|
||||
res_ans = res_text.strip()
|
||||
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
# self.prompt.append({'role': 'user', 'content': text})
|
||||
# self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
# 添加pending_msgs
|
||||
self.prompt += pending_msgs
|
||||
|
||||
# 向token_counts中添加本回合的token数量
|
||||
self.token_counts.append(total_token-total_token_before_query)
|
||||
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
|
||||
self.token_counts.append(total_tokens-total_token_before_query)
|
||||
logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
|
||||
@@ -45,7 +45,10 @@ __plugins_order__ = []
|
||||
"""插件顺序"""
|
||||
|
||||
__callable_functions__ = []
|
||||
"""供GPT调用的函数"""
|
||||
"""供GPT调用的函数结构"""
|
||||
|
||||
__function_inst_map__: dict[str, callable] = {}
|
||||
"""函数名:实例 映射"""
|
||||
|
||||
|
||||
def generate_plugin_order():
|
||||
@@ -107,6 +110,10 @@ def load_plugins():
|
||||
# 加载插件顺序
|
||||
settings.load_settings()
|
||||
|
||||
# 输出已注册的内容函数列表
|
||||
logging.debug("registered content functions: {}".format(__callable_functions__))
|
||||
logging.debug("function instance map: {}".format(__function_inst_map__))
|
||||
|
||||
|
||||
def initialize_plugins():
|
||||
"""初始化插件"""
|
||||
|
||||
@@ -189,6 +189,11 @@ class Plugin:
|
||||
def wrapper(func):
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_schema['name'] = __current_registering_plugin__ + '-' + func.__name__
|
||||
|
||||
host.__function_inst_map__[function_schema['name']] = function_schema['function']
|
||||
|
||||
del function_schema['function']
|
||||
|
||||
# logging.debug("registering content function: p='{}', f='{}', s={}".format(__current_registering_plugin__, func, function_schema))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user