feat: preliminarily implement pipeline invoking

This commit is contained in:
Junyan Qin
2025-03-29 17:50:45 +08:00
parent d01eadc70f
commit 9f15ab5000
57 changed files with 384 additions and 421 deletions
+11 -23
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
import sqlalchemy
import pydantic.v1 as pydantic
from . import entities, requester
from ...core import app
@@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: requester.LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class ModelManager:
"""模型管理器"""
@@ -47,7 +31,7 @@ class ModelManager:
ap: app.Application
llm_models: list[RuntimeLLMModel]
llm_models: list[requester.RuntimeLLMModel]
requester_components: list[engine.Component]
@@ -99,16 +83,20 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
runtime_llm_model = RuntimeLLMModel(
requester_inst = self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
await requester_inst.initialize()
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
requester=requester_inst
)
self.llm_models.append(runtime_llm_model)
+23 -14
View File
@@ -6,8 +6,27 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
from ...entity.persistence import model as persistence_model
from . import token
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class LLMAPIRequester(metaclass=abc.ABCMeta):
@@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
async def preprocess(
self,
query: core_entities.Query,
):
"""预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
@abc.abstractmethod
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
@@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
"""调用API
Args:
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
model (RuntimeLLMModel): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
@@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
client: anthropic.AsyncAnthropic
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.anthropic.com/v1',
'base_url': 'https://api.anthropic.com/v1',
'timeout': 120,
}
async def initialize(self):
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
base_url=self.requester_cfg['base_url'],
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']),
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
follow_redirects=True,
trust_env=True,
@@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester):
http_client=httpx_client,
)
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = model.name if model.model_name is None else model.model_name
args = extra_args.copy()
args["model"] = model.model_entity.name
# 处理消息
@@ -7,7 +7,7 @@ metadata:
zh_CN: Anthropic
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: 阿里云百炼
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
+7 -9
View File
@@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
"base-url": "https://api.openai.com/v1",
"base_url": "https://api.openai.com/v1",
"timeout": 120,
}
@@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self.client = openai.AsyncClient(
api_key="",
base_url=self.requester_cfg["base-url"],
base_url=self.requester_cfg["base_url"],
timeout=self.requester_cfg["timeout"],
http_client=httpx.AsyncClient(
trust_env=True, timeout=self.requester_cfg["timeout"]
@@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg["args"].copy()
args["model"] = (
use_model.name if use_model.model_name is None else use_model.model_name
)
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
return message
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
@@ -7,7 +7,7 @@ metadata:
zh_CN: OpenAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.deepseek.com',
'base_url': 'https://api.deepseek.com',
'timeout': 120,
}
@@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -7,7 +7,7 @@ metadata:
zh_CN: 深度求索
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://ai.gitee.com/v1',
'base_url': 'https://ai.gitee.com/v1',
'timeout': 120,
}
@@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -7,7 +7,7 @@ metadata:
zh_CN: Gitee AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:1234/v1',
'base_url': 'http://127.0.0.1:1234/v1',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: LM Studio
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.moonshot.cn/v1',
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 120,
}
@@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -7,7 +7,7 @@ metadata:
zh_CN: 月之暗面
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
+61 -46
View File
@@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat"
class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:11434',
'timeout': 120,
"base_url": "http://127.0.0.1:11434",
"timeout": 120,
}
async def initialize(self):
os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url']
self.client = ollama.AsyncClient(
timeout=self.requester_cfg['timeout']
)
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
async def _req(self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(
**args
)
async def _req(
self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(**args)
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
args: Any = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
args = extra_args.copy()
args["model"] = use_model.model_entity.name
messages: list[dict] = req_messages.copy()
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
if "content" in msg and isinstance(msg["content"], list):
text_content: list = []
image_urls: list = []
for me in msg["content"]:
@@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
text_content.append(me["text"])
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])
msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg['tool_calls']:
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
msg["images"] = [url.split(",")[1] for url in image_urls]
if (
"tool_calls" in msg
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg["tool_calls"]:
tool_call["function"]["arguments"] = json.loads(
tool_call["function"]["arguments"]
)
args["messages"] = messages
args["tools"] = []
@@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
return message
async def _make_msg(
self,
chat_completions: ollama.ChatResponse) -> llm_entities.Message:
self, chat_completions: ollama.ChatResponse
) -> llm_entities.Message:
message: ollama.Message = chat_completions.message
if message is None:
raise ValueError("chat_completions must contain a 'message' field")
@@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
ret_msg: llm_entities.Message = None
if message.content is not None:
ret_msg = llm_entities.Message(
role="assistant",
content=message.content
)
ret_msg = llm_entities.Message(role="assistant", content=message.content)
if message.tool_calls is not None and len(message.tool_calls) > 0:
tool_calls: list[llm_entities.ToolCall] = []
for tool_call in message.tool_calls:
tool_calls.append(llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments)
tool_calls.append(
llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments),
),
)
))
)
ret_msg.tool_calls = tool_calls
return ret_msg
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
async def invoke_llm(
self,
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages: list = []
for m in messages:
msg_dict: dict = m.dict(exclude_none=True)
content: Any = msg_dict.get("content")
if isinstance(content, list):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
if all(
isinstance(part, dict) and part.get("type") == "text"
for part in content
):
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(query, req_messages, model, funcs, extra_args)
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
raise errors.RequesterError("请求超时")
@@ -7,7 +7,7 @@ metadata:
zh_CN: Ollama
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.siliconflow.cn/v1',
'base_url': 'https://api.siliconflow.cn/v1',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: 硅基流动
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: 火山方舟
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.x.ai/v1',
'base_url': 'https://api.x.ai/v1',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: xAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL
@@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
'timeout': 120,
}
@@ -7,7 +7,7 @@ metadata:
zh_CN: 智谱 AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL