mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 08:16:03 +00:00
chore: 修改包名
This commit is contained in:
1
pkg/provider/__init__.py
Normal file
1
pkg/provider/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""OpenAI 接口处理及会话管理相关"""
|
||||
0
pkg/provider/api/__init__.py
Normal file
0
pkg/provider/api/__init__.py
Normal file
232
pkg/provider/api/chat_completion.py
Normal file
232
pkg/provider/api/chat_completion.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
|
||||
from .model import RequestBase
|
||||
from .. import funcmgr
|
||||
from ...plugin import host
|
||||
from ...utils import context
|
||||
|
||||
|
||||
class ChatCompletionRequest(RequestBase):
|
||||
"""调用ChatCompletion接口的请求类。
|
||||
|
||||
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
|
||||
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
|
||||
"""
|
||||
|
||||
model: str
|
||||
messages: list[dict[str, str]]
|
||||
kwargs: dict
|
||||
|
||||
stopped: bool = False
|
||||
|
||||
pending_func_call: chat_completion_message.FunctionCall = None
|
||||
|
||||
pending_msg: str
|
||||
|
||||
def flush_pending_msg(self):
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content=self.pending_msg
|
||||
)
|
||||
self.pending_msg = ""
|
||||
|
||||
def append_message(self, role: str, content: str, name: str=None, function_call: dict=None):
|
||||
msg = {
|
||||
"role": role,
|
||||
"content": content
|
||||
}
|
||||
|
||||
if name is not None:
|
||||
msg['name'] = name
|
||||
|
||||
if function_call is not None:
|
||||
msg['function_call'] = function_call
|
||||
|
||||
self.messages.append(msg)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: openai.Client,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.messages = messages.copy()
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = self.client.chat.completions.create
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
self.stopped = False
|
||||
|
||||
self.pending_msg = ""
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> dict:
|
||||
if self.stopped:
|
||||
raise StopIteration()
|
||||
|
||||
if self.pending_func_call is None: # 没有待处理的函数调用请求
|
||||
|
||||
args = {
|
||||
"model": self.model,
|
||||
"messages": self.messages,
|
||||
}
|
||||
|
||||
funcs = funcmgr.get_func_schema_list()
|
||||
|
||||
if len(funcs) > 0:
|
||||
args['functions'] = funcs
|
||||
|
||||
# 拼接kwargs
|
||||
args = {**args, **self.kwargs}
|
||||
|
||||
from openai.types.chat import chat_completion
|
||||
|
||||
resp: chat_completion.ChatCompletion = self._req(**args)
|
||||
|
||||
choice0 = resp.choices[0]
|
||||
|
||||
# 如果不是函数调用,且finish_reason为stop,则停止迭代
|
||||
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
|
||||
self.stopped = True
|
||||
|
||||
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
|
||||
self.pending_func_call = choice0.message.function_call
|
||||
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content=choice0.message.content,
|
||||
function_call=choice0.message.function_call
|
||||
)
|
||||
|
||||
return {
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "function_call",
|
||||
"content": choice0.message.content,
|
||||
"function_call": {
|
||||
"name": choice0.message.function_call.name,
|
||||
"arguments": choice0.message.function_call.arguments
|
||||
}
|
||||
},
|
||||
"finish_reason": "function_call"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
}
|
||||
else:
|
||||
|
||||
# self.pending_msg += choice0['message']['content']
|
||||
# 普通回复一定处于最后方,故不用再追加进内部messages
|
||||
|
||||
return {
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0.message.content
|
||||
},
|
||||
"finish_reason": choice0.finish_reason
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
}
|
||||
else: # 处理函数调用请求
|
||||
|
||||
cp_pending_func_call = self.pending_func_call.copy()
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
func_name = cp_pending_func_call.name
|
||||
arguments = {}
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
arguments = json.loads(cp_pending_func_call.arguments)
|
||||
# 若不是json格式的异常处理
|
||||
except json.decoder.JSONDecodeError:
|
||||
# 获取函数的参数列表
|
||||
func_schema = funcmgr.get_func_schema(func_name)
|
||||
|
||||
arguments = {
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
|
||||
}
|
||||
|
||||
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
|
||||
|
||||
# 执行函数调用
|
||||
ret = ""
|
||||
try:
|
||||
ret = funcmgr.execute_function(func_name, arguments)
|
||||
|
||||
logging.info("函数执行完成。")
|
||||
except Exception as e:
|
||||
ret = "error: execute function failed: {}".format(str(e))
|
||||
logging.error("函数执行失败: {}".format(str(e)))
|
||||
|
||||
# 上报数据
|
||||
plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0])
|
||||
audit_func_name = func_name.split('-')[1]
|
||||
audit_func_desc = funcmgr.get_func_schema(func_name)['description']
|
||||
context.get_center_v2_api().usage.post_function_record(
|
||||
plugin=plugin_info,
|
||||
function_name=audit_func_name,
|
||||
function_description=audit_func_desc,
|
||||
)
|
||||
|
||||
self.append_message(
|
||||
role="function",
|
||||
content=json.dumps(ret, ensure_ascii=False),
|
||||
name=func_name
|
||||
)
|
||||
|
||||
return {
|
||||
"id": -1,
|
||||
"choices": [
|
||||
{
|
||||
"index": -1,
|
||||
"message": {
|
||||
"role": "function",
|
||||
"type": "function_return",
|
||||
"function_name": func_name,
|
||||
"content": json.dumps(ret, ensure_ascii=False)
|
||||
},
|
||||
"finish_reason": "function_return"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
except funcmgr.ContentFunctionNotFoundError:
|
||||
raise Exception("没有找到函数: {}".format(func_name))
|
||||
100
pkg/provider/api/completion.py
Normal file
100
pkg/provider/api/completion.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import openai
|
||||
from openai.types import completion, completion_choice
|
||||
|
||||
from . import model
|
||||
|
||||
|
||||
class CompletionRequest(model.RequestBase):
|
||||
"""调用Completion接口的请求类。
|
||||
|
||||
调用方可以一直next completion直到finish_reason为stop。
|
||||
"""
|
||||
|
||||
model: str
|
||||
prompt: str
|
||||
kwargs: dict
|
||||
|
||||
stopped: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: openai.Client,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.prompt = ""
|
||||
|
||||
for message in messages:
|
||||
self.prompt += message["role"] + ": " + message["content"] + "\n"
|
||||
|
||||
self.prompt += "assistant: "
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = self.client.completions.create
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> dict:
|
||||
"""调用Completion接口,返回生成的文本
|
||||
|
||||
{
|
||||
"id": "id",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": "message"
|
||||
},
|
||||
"finish_reason": "reason"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
if self.stopped:
|
||||
raise StopIteration()
|
||||
|
||||
resp: completion.Completion = self._req(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
**self.kwargs
|
||||
)
|
||||
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
self.stopped = True
|
||||
|
||||
choice0: completion_choice.CompletionChoice = resp.choices[0]
|
||||
|
||||
self.prompt += choice0.text
|
||||
|
||||
return {
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0.text
|
||||
},
|
||||
"finish_reason": choice0.finish_reason
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
}
|
||||
40
pkg/provider/api/model.py
Normal file
40
pkg/provider/api/model.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# 定义不同接口请求的模型
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
||||
from ...utils import context
|
||||
|
||||
|
||||
class RequestBase:
|
||||
|
||||
client: openai.Client
|
||||
|
||||
req_func: callable
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _next_key(self):
|
||||
switched, name = context.get_openai_manager().key_mgr.auto_switch()
|
||||
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
|
||||
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
|
||||
|
||||
def _req(self, **kwargs):
|
||||
"""处理代理问题"""
|
||||
logging.debug("请求接口参数: %s", str(kwargs))
|
||||
config = context.get_config_manager().data
|
||||
|
||||
ret = self.req_func(**kwargs)
|
||||
logging.debug("接口请求返回:%s", str(ret))
|
||||
|
||||
if config['switch_strategy'] == 'active':
|
||||
self._next_key()
|
||||
|
||||
return ret
|
||||
|
||||
def __iter__(self):
|
||||
raise self
|
||||
|
||||
def __next__(self):
|
||||
raise NotImplementedError
|
||||
33
pkg/provider/entities.py
Normal file
33
pkg/provider/entities.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import enum
|
||||
import pydantic
|
||||
|
||||
|
||||
class FunctionCall(pydantic.BaseModel):
|
||||
name: str
|
||||
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
type: str
|
||||
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
role: str
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
|
||||
tool_call_id: typing.Optional[str] = None
|
||||
139
pkg/provider/modelmgr.py
Normal file
139
pkg/provider/modelmgr.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""OpenAI 接口底层封装
|
||||
|
||||
目前使用的对话接口有:
|
||||
ChatCompletion - gpt-3.5-turbo 等模型
|
||||
Completion - text-davinci-003 等模型
|
||||
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
|
||||
"""
|
||||
import tiktoken
|
||||
import openai
|
||||
|
||||
from ..provider.api import model as api_model
|
||||
from ..provider.api import completion as api_completion
|
||||
from ..provider.api import chat_completion as api_chat_completion
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
"gpt-3.5-turbo-instruct",
|
||||
}
|
||||
|
||||
CHAT_COMPLETION_MODELS = {
|
||||
# GPT 4 系列
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-0314", # legacy
|
||||
"gpt-4-32k-0314", # legacy
|
||||
# GPT 3.5 系列
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0613", # legacy
|
||||
"gpt-3.5-turbo-16k-0613", # legacy
|
||||
"gpt-3.5-turbo-0301", # legacy
|
||||
# One-API 接入
|
||||
"SparkDesk",
|
||||
"chatglm_pro",
|
||||
"chatglm_std",
|
||||
"chatglm_lite",
|
||||
"qwen-v1",
|
||||
"qwen-plus-v1",
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Bot-turbo",
|
||||
"gemini-pro",
|
||||
}
|
||||
|
||||
EDIT_MODELS = {
|
||||
|
||||
}
|
||||
|
||||
IMAGE_MODELS = {
|
||||
|
||||
}
|
||||
|
||||
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
return api_completion.CompletionRequest(client, model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
|
||||
|
||||
def count_chat_completion_tokens(messages: list, model: str) -> int:
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"SparkDesk",
|
||||
"chatglm_pro",
|
||||
"chatglm_std",
|
||||
"chatglm_lite",
|
||||
"qwen-v1",
|
||||
"qwen-plus-v1",
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Bot-turbo",
|
||||
"gemini-pro",
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model:
|
||||
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_chat_completion_tokens(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||
)
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
def count_completion_tokens(messages: list, model: str) -> int:
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
text = ""
|
||||
|
||||
for message in messages:
|
||||
text += message['role'] + message['content'] + "\n"
|
||||
|
||||
text += "assistant: "
|
||||
|
||||
return len(encoding.encode(text))
|
||||
|
||||
|
||||
def count_tokens(messages: list, model: str):
|
||||
|
||||
if model in CHAT_COMPLETION_MODELS:
|
||||
return count_chat_completion_tokens(messages, model)
|
||||
elif model in COMPLETION_MODELS:
|
||||
return count_completion_tokens(messages, model)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model))
|
||||
0
pkg/provider/requester/__init__.py
Normal file
0
pkg/provider/requester/__init__.py
Normal file
31
pkg/provider/requester/api.py
Normal file
31
pkg/provider/requester/api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
conversation: session_entities.Conversation,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
0
pkg/provider/requester/apis/__init__.py
Normal file
0
pkg/provider/requester/apis/__init__.py
Normal file
140
pkg/provider/requester/apis/chatcmpl.py
Normal file
140
pkg/provider/requester/apis/chatcmpl.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import json
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
|
||||
from .. import api
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...session import entities as session_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
client: openai.AsyncClient
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"],
|
||||
timeout=self.ap.cfg_mgr.data["process_message_timeout"],
|
||||
)
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
) -> chat_completion.ChatCompletion:
|
||||
self.ap.logger.debug(f"req chat_completion with args {args}")
|
||||
return await self.client.chat.completions.create(**args)
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.dict()
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
conversation: session_entities.Conversation,
|
||||
user_text: str = None,
|
||||
function_ret: str = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = conversation.use_model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
|
||||
args["model"] = conversation.use_model.name
|
||||
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
|
||||
# tools = [
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA",
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"],
|
||||
# },
|
||||
# },
|
||||
# "required": ["location"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
|
||||
async def request(
|
||||
self, query: core_entities.Query, conversation: session_entities.Conversation
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求"""
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = [
|
||||
m.dict(exclude_none=True) for m in conversation.prompt.messages
|
||||
] + [m.dict(exclude_none=True) for m in conversation.messages]
|
||||
|
||||
# req_messages.append({"role": "user", "content": str(query.message_chain)})
|
||||
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
# 处理完所有调用,继续请求
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
23
pkg/provider/requester/entities.py
Normal file
23
pkg/provider/requester/entities.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import api
|
||||
from . import token
|
||||
|
||||
|
||||
class LLMModelInfo(pydantic.BaseModel):
|
||||
"""模型"""
|
||||
|
||||
name: str
|
||||
|
||||
provider: str
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
|
||||
requester: api.LLMAPIRequester
|
||||
|
||||
function_call_supported: typing.Optional[bool] = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
41
pkg/provider/requester/modelmgr.py
Normal file
41
pkg/provider/requester/modelmgr.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from .apis import chatcmpl
|
||||
from . import token
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
|
||||
async def initialize(self):
|
||||
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||
await openai_chat_completion.initialize()
|
||||
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
|
||||
|
||||
self.model_list.append(
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo",
|
||||
provider="openai",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
function_call_supported=True
|
||||
)
|
||||
)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"Model {name} not found")
|
||||
25
pkg/provider/requester/token.py
Normal file
25
pkg/provider/requester/token.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class TokenManager():
|
||||
|
||||
provider: str
|
||||
|
||||
tokens: list[str]
|
||||
|
||||
using_token_index: typing.Optional[int] = 0
|
||||
|
||||
def __init__(self, provider: str, tokens: list[str]):
|
||||
self.provider = provider
|
||||
self.tokens = tokens
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||
0
pkg/provider/session/__init__.py
Normal file
0
pkg/provider/session/__init__.py
Normal file
53
pkg/provider/session/entities.py
Normal file
53
pkg/provider/session/entities.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..sysprompt import entities as sysprompt_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..requester import entities
|
||||
from ...core import entities as core_entities
|
||||
from ..tools import entities as tools_entities
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_model: entities.LLMModelInfo
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话"""
|
||||
launcher_type: core_entities.LauncherTypes
|
||||
|
||||
launcher_id: int
|
||||
|
||||
sender_id: typing.Optional[int] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = []
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
51
pkg/provider/session/sessionmgr.py
Normal file
51
pkg/provider/session/sessionmgr.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
session_list: list[entities.Session]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.session_list = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def get_session(self, query: core_entities.Query) -> entities.Session:
|
||||
"""获取会话
|
||||
"""
|
||||
for session in self.session_list:
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
return session
|
||||
|
||||
session = entities.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
|
||||
)
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
session.using_conversation = conversation
|
||||
|
||||
return session.using_conversation
|
||||
0
pkg/provider/sysprompt/__init__.py
Normal file
0
pkg/provider/sysprompt/__init__.py
Normal file
14
pkg/provider/sysprompt/entities.py
Normal file
14
pkg/provider/sysprompt/entities.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic
|
||||
|
||||
from ...provider import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
|
||||
messages: list[entities.Message]
|
||||
32
pkg/provider/sysprompt/loader.py
Normal file
32
pkg/provider/sysprompt/loader.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
0
pkg/provider/sysprompt/loaders/__init__.py
Normal file
0
pkg/provider/sysprompt/loaders/__init__.py
Normal file
38
pkg/provider/sysprompt/loaders/scenario.py
Normal file
38
pkg/provider/sysprompt/loaders/scenario.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("scenarios"):
|
||||
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = 'system'
|
||||
if "role" in msg:
|
||||
role = msg['role']
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
42
pkg/provider/sysprompt/loaders/single.py
Normal file
42
pkg/provider/sysprompt/loaders/single.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("prompts"):
|
||||
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
50
pkg/provider/sysprompt/sysprompt.py
Normal file
50
pkg/provider/sysprompt/sysprompt.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
|
||||
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
|
||||
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
|
||||
"""通过前缀获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name.startswith(prefix):
|
||||
return prompt
|
||||
0
pkg/provider/tools/__init__.py
Normal file
0
pkg/provider/tools/__init__.py
Normal file
35
pkg/provider/tools/entities.py
Normal file
35
pkg/provider/tools/entities.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LLMFunction(pydantic.BaseModel):
|
||||
"""函数"""
|
||||
|
||||
name: str
|
||||
"""函数名"""
|
||||
|
||||
human_desc: str
|
||||
|
||||
description: str
|
||||
"""给LLM识别的函数描述"""
|
||||
|
||||
enable: typing.Optional[bool] = True
|
||||
|
||||
parameters: dict
|
||||
|
||||
func: typing.Callable
|
||||
"""供调用的python异步方法
|
||||
|
||||
此异步方法第一个参数接收当前请求的query对象,可以从其中取出session等信息。
|
||||
query参数不在parameters中,但在调用时会自动传入。
|
||||
但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中,
|
||||
对插件的内容函数进行封装并存到这里来。
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
99
pkg/provider/tools/toolmgr.py
Normal file
99
pkg/provider/tools/toolmgr.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""LLM工具管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
all_functions: list[entities.LLMFunction]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.all_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable):
|
||||
"""注册函数
|
||||
"""
|
||||
async def wrapper(query, **kwargs):
|
||||
return func(**kwargs)
|
||||
function = entities.LLMFunction(
|
||||
name=name,
|
||||
description=description,
|
||||
human_desc='',
|
||||
enable=True,
|
||||
parameters=parameters,
|
||||
func=wrapper
|
||||
)
|
||||
self.all_functions.append(function)
|
||||
|
||||
async def register_function(self, function: entities.LLMFunction):
|
||||
"""添加函数
|
||||
"""
|
||||
self.all_functions.append(function)
|
||||
|
||||
async def get_function(self, name: str) -> entities.LLMFunction:
|
||||
"""获取函数
|
||||
"""
|
||||
for function in self.all_functions:
|
||||
if function.name == name:
|
||||
return function
|
||||
return None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数
|
||||
"""
|
||||
return self.all_functions
|
||||
|
||||
async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str:
|
||||
"""生成函数列表
|
||||
"""
|
||||
tools = []
|
||||
|
||||
for function in conversation.use_funcs:
|
||||
if function.enable:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters
|
||||
}
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
return tools
|
||||
|
||||
async def execute_func_call(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
name: str,
|
||||
parameters: dict
|
||||
) -> typing.Any:
|
||||
"""执行函数调用
|
||||
"""
|
||||
|
||||
# return "i'm not sure for the args "+str(parameters)
|
||||
|
||||
function = await self.get_function(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
parameters = parameters.copy()
|
||||
|
||||
parameters = {
|
||||
"query": query,
|
||||
**parameters
|
||||
}
|
||||
|
||||
return await function.func(**parameters)
|
||||
Reference in New Issue
Block a user