mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-12 00:36:03 +00:00
feat: preliminarily implement pipeline invoking
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import sqlalchemy
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from . import entities, requester
|
||||
from ...core import app
|
||||
@@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
||||
class RuntimeLLMModel:
    """Runtime model.

    Groups the model data, the API key manager and the requester instance
    that belong to one configured LLM model.
    """

    # Model data.
    model_entity: persistence_model.LLMModel

    # API key manager.
    token_mgr: token.TokenManager

    # Requester instance.
    requester: requester.LLMAPIRequester

    def __init__(
        self,
        model_entity: persistence_model.LLMModel,
        token_mgr: token.TokenManager,
        requester: requester.LLMAPIRequester,
    ):
        """Store the given collaborators on the instance.

        Args:
            model_entity: Model data for this runtime model.
            token_mgr: API key manager used for this model.
            requester: Requester instance used for this model.
        """
        self.model_entity = model_entity
        self.token_mgr = token_mgr
        self.requester = requester
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器"""
|
||||
@@ -47,7 +31,7 @@ class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
llm_models: list[RuntimeLLMModel]
|
||||
llm_models: list[requester.RuntimeLLMModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
@@ -99,16 +83,20 @@ class ModelManager:
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
|
||||
runtime_llm_model = RuntimeLLMModel(
|
||||
requester_inst = self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
)
|
||||
|
||||
await requester_inst.initialize()
|
||||
|
||||
runtime_llm_model = requester.RuntimeLLMModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
),
|
||||
requester=self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
)
|
||||
requester=requester_inst
|
||||
)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
|
||||
|
||||
@@ -6,8 +6,27 @@ import typing
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from . import entities as modelmgr_entities
|
||||
from ..tools import entities as tools_entities
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from . import token
|
||||
|
||||
|
||||
class RuntimeLLMModel:
    """Runtime model.

    Groups the model data, the API key manager and the requester instance
    that belong to one configured LLM model.
    """

    # Model data.
    model_entity: persistence_model.LLMModel

    # API key manager.
    token_mgr: token.TokenManager

    # Requester instance.
    requester: LLMAPIRequester

    def __init__(
        self,
        model_entity: persistence_model.LLMModel,
        token_mgr: token.TokenManager,
        requester: LLMAPIRequester,
    ):
        """Store the given collaborators on the instance.

        Args:
            model_entity: Model data for this runtime model.
            token_mgr: API key manager used for this model.
            requester: Requester instance used for this model.
        """
        self.model_entity = model_entity
        self.token_mgr = token_mgr
        self.requester = requester
|
||||
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
@@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def preprocess(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
):
|
||||
"""预处理
|
||||
|
||||
在这里处理特定API对Query对象的兼容性问题。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: modelmgr_entities.LLMModelInfo,
|
||||
model: RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
@@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""调用API
|
||||
|
||||
Args:
|
||||
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
|
||||
model (RuntimeLLMModel): 使用的模型信息
|
||||
messages (typing.List[llm_entities.Message]): 消息对象列表
|
||||
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
|
||||
@@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.anthropic.com/v1',
|
||||
'base_url': 'https://api.anthropic.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
|
||||
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
# cast to a valid type because mypy doesn't understand our type narrowing
|
||||
timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']),
|
||||
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
|
||||
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
|
||||
follow_redirects=True,
|
||||
trust_env=True,
|
||||
@@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
http_client=httpx_client,
|
||||
)
|
||||
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: entities.LLMModelInfo,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
|
||||
args["model"] = model.name if model.model_name is None else model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = model.model_entity.name
|
||||
|
||||
# 处理消息
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Anthropic
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 阿里云百炼
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
"base-url": "https://api.openai.com/v1",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.requester_cfg["base-url"],
|
||||
base_url=self.requester_cfg["base_url"],
|
||||
timeout=self.requester_cfg["timeout"],
|
||||
http_client=httpx.AsyncClient(
|
||||
trust_env=True, timeout=self.requester_cfg["timeout"]
|
||||
@@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg["args"].copy()
|
||||
args["model"] = (
|
||||
use_model.name if use_model.model_name is None else use_model.model_name
|
||||
)
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
@@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
return message
|
||||
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: entities.LLMModelInfo,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: OpenAI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Deepseek ChatCompletion API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.deepseek.com',
|
||||
'base_url': 'https://api.deepseek.com',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 深度求索
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Gitee AI ChatCompletions API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://ai.gitee.com/v1',
|
||||
'base_url': 'https://ai.gitee.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Gitee AI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'http://127.0.0.1:1234/v1',
|
||||
'base_url': 'http://127.0.0.1:1234/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: LM Studio
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Moonshot ChatCompletion API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.moonshot.cn/v1',
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 月之暗面
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat"
|
||||
|
||||
class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
"""Ollama平台 ChatCompletion API请求器"""
|
||||
|
||||
client: ollama.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'http://127.0.0.1:11434',
|
||||
'timeout': 120,
|
||||
"base_url": "http://127.0.0.1:11434",
|
||||
"timeout": 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
    """Create the Ollama async client from this requester's configuration.

    Reads ``base_url`` and ``timeout`` from ``self.requester_cfg`` and
    stores the resulting client on ``self.client``.
    """
    base_url = self.requester_cfg["base_url"]

    # Still exported, as before, in case other code reads OLLAMA_HOST.
    os.environ["OLLAMA_HOST"] = base_url

    # Pass the host explicitly so this client does not depend on a
    # process-wide environment variable that another requester instance
    # may overwrite with a different base URL.
    self.client = ollama.AsyncClient(
        host=base_url,
        timeout=self.requester_cfg["timeout"],
    )
|
||||
|
||||
async def _req(
    self,
    args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
    """Send one chat request through the Ollama client.

    Args:
        args: Keyword arguments forwarded unchanged to ``self.client.chat``.

    Returns:
        Whatever ``self.client.chat`` returns for those arguments.
    """
    return await self.client.chat(**args)
|
||||
|
||||
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
|
||||
args: Any = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
async def _closure(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
messages: list[dict] = req_messages.copy()
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg["content"]:
|
||||
@@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
text_content.append(me["text"])
|
||||
elif me["type"] == "image_base64":
|
||||
image_urls.append(me["image_base64"])
|
||||
|
||||
|
||||
msg["content"] = "\n".join(text_content)
|
||||
msg["images"] = [url.split(',')[1] for url in image_urls]
|
||||
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg['tool_calls']:
|
||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
||||
msg["images"] = [url.split(",")[1] for url in image_urls]
|
||||
if (
|
||||
"tool_calls" in msg
|
||||
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_call["function"]["arguments"] = json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
)
|
||||
args["messages"] = messages
|
||||
|
||||
args["tools"] = []
|
||||
@@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
return message
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completions: ollama.ChatResponse) -> llm_entities.Message:
|
||||
self, chat_completions: ollama.ChatResponse
|
||||
) -> llm_entities.Message:
|
||||
message: ollama.Message = chat_completions.message
|
||||
if message is None:
|
||||
raise ValueError("chat_completions must contain a 'message' field")
|
||||
@@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
ret_msg: llm_entities.Message = None
|
||||
|
||||
if message.content is not None:
|
||||
ret_msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=message.content
|
||||
)
|
||||
ret_msg = llm_entities.Message(role="assistant", content=message.content)
|
||||
if message.tool_calls is not None and len(message.tool_calls) > 0:
|
||||
tool_calls: list[llm_entities.ToolCall] = []
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
tool_calls.append(llm_entities.ToolCall(
|
||||
id=uuid.uuid4().hex,
|
||||
type="function",
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=json.dumps(tool_call.function.arguments)
|
||||
tool_calls.append(
|
||||
llm_entities.ToolCall(
|
||||
id=uuid.uuid4().hex,
|
||||
type="function",
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=json.dumps(tool_call.function.arguments),
|
||||
),
|
||||
)
|
||||
))
|
||||
)
|
||||
ret_msg.tool_calls = tool_calls
|
||||
|
||||
return ret_msg
|
||||
|
||||
async def invoke_llm(
    self,
    query: core_entities.Query,
    model: requester.RuntimeLLMModel,
    messages: typing.List[llm_entities.Message],
    funcs: typing.List[tools_entities.LLMFunction] = None,
    extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
    """Convert the messages to plain dicts and run one chat request.

    Each message is dumped with ``dict(exclude_none=True)``. When a
    message's ``content`` is a list made up only of ``{"type": "text"}``
    parts, the parts are joined with newlines into a single string; any
    other content is passed through unchanged.

    Args:
        query: The query being processed; forwarded to ``_closure``.
        model: Runtime model to use; forwarded as ``use_model``.
        messages: Message objects to send.
        funcs: Optional tool functions; forwarded as ``user_funcs``.
        extra_args: Extra request arguments; forwarded untouched (this
            method never mutates it, so the shared default is safe here).

    Returns:
        The message produced by ``_closure``.

    Raises:
        errors.RequesterError: If the request raises ``asyncio.TimeoutError``.
    """
    req_messages: list = []
    for m in messages:
        msg_dict: dict = m.dict(exclude_none=True)
        content: Any = msg_dict.get("content")
        # Text-only multipart content is flattened to one string.
        if isinstance(content, list) and all(
            isinstance(part, dict) and part.get("type") == "text"
            for part in content
        ):
            msg_dict["content"] = "\n".join(part["text"] for part in content)
        req_messages.append(msg_dict)

    try:
        return await self._closure(
            query=query,
            req_messages=req_messages,
            use_model=model,
            # `_closure` names this parameter `user_funcs`; passing it as
            # `use_funcs` raises TypeError (unexpected keyword argument).
            user_funcs=funcs,
            extra_args=extra_args,
        )
    except asyncio.TimeoutError:
        raise errors.RequesterError("请求超时")
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Ollama
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.siliconflow.cn/v1',
|
||||
'base_url': 'https://api.siliconflow.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 硅基流动
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
|
||||
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 火山方舟
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.x.ai/v1',
|
||||
'base_url': 'https://api.x.ai/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: xAI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
|
||||
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 智谱 AI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -27,11 +27,11 @@ class RequestRunner(abc.ABC):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
pipeline_config: dict
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import runner
|
||||
from ..core import app
|
||||
|
||||
from .runners import localagent
|
||||
from .runners import difysvapi
|
||||
from .runners import dashscopeapi
|
||||
|
||||
class RunnerManager:
    """Holds the request runner selected by the provider configuration."""

    ap: app.Application

    # Set by initialize(); not available before it has run.
    using_runner: runner.RequestRunner

    def __init__(self, ap: app.Application):
        """Remember the application object; no runner is created yet."""
        self.ap = ap

    async def initialize(self):
        """Instantiate and initialize the configured runner.

        Walks ``runner.preregistered_runners`` and uses the first class whose
        ``name`` equals ``self.ap.provider_cfg.data['runner']``.

        Raises:
            ValueError: If no preregistered runner has the configured name.
        """
        # Read the configured name once instead of on every iteration.
        runner_name = self.ap.provider_cfg.data['runner']

        for runner_cls in runner.preregistered_runners:
            if runner_cls.name == runner_name:
                self.using_runner = runner_cls(self.ap)
                await self.using_runner.initialize()
                return

        raise ValueError(f"未找到请求运行器: {runner_name}")

    def get_runner(self) -> runner.RequestRunner:
        """Return the runner chosen by initialize()."""
        return self.using_runner
|
||||
@@ -8,7 +8,7 @@ import re
|
||||
import dashscope
|
||||
|
||||
from .. import runner
|
||||
from ...core import entities as core_entities
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ...utils import image
|
||||
|
||||
@@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效)
|
||||
|
||||
async def initialize(self):
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ["agent", "workflow"]
|
||||
self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"]
|
||||
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
|
||||
#检查配置文件中使用的应用类型是否支持
|
||||
if (self.app_type not in valid_app_types):
|
||||
raise DashscopeAPIError(
|
||||
@@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
|
||||
#初始化Dashscope 参数配置
|
||||
self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"]
|
||||
self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"]
|
||||
self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"]
|
||||
self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"]
|
||||
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
|
||||
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
|
||||
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
|
||||
|
||||
def _replace_references(self, text, references_dict):
|
||||
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
|
||||
@@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
biz_params = {}
|
||||
biz_params.update(self.biz_params)
|
||||
biz_params.update(query.variables)
|
||||
|
||||
#发送对话请求
|
||||
@@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def run(
    self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
    """Yield the reply messages for ``query`` according to ``self.app_type``.

    ``"agent"`` streams from ``_agent_messages`` and ``"workflow"`` streams
    from ``_workflow_messages``.

    Raises:
        DashscopeAPIError: If ``self.app_type`` is neither of those values.
    """
    if self.app_type == "agent":
        async for msg in self._agent_messages(query):
            yield msg
    elif self.app_type == "workflow":
        async for msg in self._workflow_messages(query):
            yield msg
    else:
        raise DashscopeAPIError(
            f"不支持的 Dashscope 应用类型: {self.app_type}"
        )
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import datetime
|
||||
import aiohttp
|
||||
|
||||
from .. import runner
|
||||
from ...core import entities as core_entities
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ...utils import image
|
||||
|
||||
@@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
dify_client: client.AsyncDifyServiceClient
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
    """Validate the Dify settings and create the service client.

    All settings are read from ``pipeline_config["ai"]["dify-service-api"]``:
    ``app-type`` is checked against the supported types, and ``api-key`` and
    ``base-url`` are passed to ``client.AsyncDifyServiceClient``.

    Args:
        ap: Application object, stored on ``self.ap``.
        pipeline_config: Pipeline configuration, stored on
            ``self.pipeline_config``.

    Raises:
        errors.DifyAPIError: If ``app-type`` is not ``chat``, ``agent`` or
            ``workflow``.
    """
    self.ap = ap
    self.pipeline_config = pipeline_config

    valid_app_types = ["chat", "agent", "workflow"]

    # Bind the nested section once rather than re-indexing it for every key.
    dify_cfg = self.pipeline_config["ai"]["dify-service-api"]

    if dify_cfg["app-type"] not in valid_app_types:
        raise errors.DifyAPIError(
            f"不支持的 Dify 应用类型: {dify_cfg['app-type']}"
        )

    api_key = dify_cfg["api-key"]

    self.dify_client = client.AsyncDifyServiceClient(
        api_key=api_key,
        base_url=dify_cfg["base-url"],
    )
|
||||
|
||||
def _try_convert_thinking(self, resp_text: str) -> str:
|
||||
@@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
|
||||
return resp_text
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
|
||||
return resp_text
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
|
||||
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
|
||||
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
|
||||
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
|
||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||
@@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
|
||||
|
||||
@@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
response_mode="streaming",
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
|
||||
|
||||
@@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
inputs=inputs,
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
|
||||
if chunk["event"] in ignored_events:
|
||||
@@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=chunk["data"]["outputs"][
|
||||
self.ap.provider_cfg.data["dify-service-api"]["workflow"][
|
||||
"output-key"
|
||||
]
|
||||
],
|
||||
content=chunk["data"]["outputs"]["summary"],
|
||||
)
|
||||
|
||||
yield msg
|
||||
@@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent":
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow":
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
|
||||
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
|
||||
)
|
||||
|
||||
@@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
@@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class SessionManager:
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
@@ -51,7 +51,9 @@ class SessionManager:
|
||||
conversation = core_entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
|
||||
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
|
||||
query.pipeline_config['ai']['local-agent']['model']
|
||||
),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user