mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-09 23:36:02 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import enum
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from pkg.provider import entities
|
||||
@@ -32,7 +31,6 @@ class ImageURLContentObject(pydantic.BaseModel):
|
||||
|
||||
|
||||
class ContentElement(pydantic.BaseModel):
|
||||
|
||||
type: str
|
||||
"""内容类型"""
|
||||
|
||||
@@ -57,7 +55,7 @@ class ContentElement(pydantic.BaseModel):
|
||||
@classmethod
|
||||
def from_image_url(cls, image_url: str):
|
||||
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_image_base64(cls, image_base64: str):
|
||||
return cls(type='image_base64', image_base64=image_base64)
|
||||
@@ -82,15 +80,19 @@ class Message(pydantic.BaseModel):
|
||||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return str(self.role) + ": " + str(self.get_content_platform_message_chain())
|
||||
return (
|
||||
str(self.role) + ': ' + str(self.get_content_platform_message_chain())
|
||||
)
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
|
||||
def get_content_platform_message_chain(
|
||||
self, prefix_text: str = ''
|
||||
) -> platform_message.MessageChain | None:
|
||||
"""将内容转换为平台消息 MessageChain 对象
|
||||
|
||||
|
||||
Args:
|
||||
prefix_text (str): 首个文字组件的前缀文本
|
||||
"""
|
||||
@@ -98,21 +100,22 @@ class Message(pydantic.BaseModel):
|
||||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(prefix_text + self.content)]
|
||||
)
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
if ce.type == 'text':
|
||||
mc.append(platform_message.Plain(ce.text))
|
||||
elif ce.type == 'image_url':
|
||||
if ce.image_url.url.startswith("http"):
|
||||
if ce.image_url.url.startswith('http'):
|
||||
mc.append(platform_message.Image(url=ce.image_url.url))
|
||||
else: # base64
|
||||
|
||||
b64_str = ce.image_url.url
|
||||
|
||||
if b64_str.startswith("data:"):
|
||||
b64_str = b64_str.split(",")[1]
|
||||
if b64_str.startswith('data:'):
|
||||
b64_str = b64_str.split(',')[1]
|
||||
|
||||
mc.append(platform_message.Image(base64=b64_str))
|
||||
|
||||
@@ -120,7 +123,7 @@ class Message(pydantic.BaseModel):
|
||||
if prefix_text:
|
||||
for i, c in enumerate(mc):
|
||||
if isinstance(c, platform_message.Plain):
|
||||
mc[i] = platform_message.Plain(prefix_text+c.text)
|
||||
mc[i] = platform_message.Plain(prefix_text + c.text)
|
||||
break
|
||||
else:
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
@@ -2,4 +2,4 @@ class RequesterError(Exception):
|
||||
"""Base class for all Requester errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__("模型请求失败: "+message)
|
||||
super().__init__('模型请求失败: ' + message)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import sqlalchemy
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from . import entities, requester
|
||||
from ...core import app
|
||||
@@ -12,10 +10,8 @@ from ..tools import entities as tools_entities
|
||||
from ...discover import engine
|
||||
from . import token
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, volcarkchatcmpl, xaichatcmpl, zhipuaichatcmpl, lmstudiochatcmpl, siliconflowchatcmpl, volcarkchatcmpl
|
||||
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list'
|
||||
|
||||
|
||||
class ModelManager:
|
||||
@@ -36,7 +32,7 @@ class ModelManager:
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
|
||||
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
@@ -45,14 +41,18 @@ class ModelManager:
|
||||
self.llm_models = []
|
||||
self.requester_components = []
|
||||
self.requester_dict = {}
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
self.requester_components = self.ap.discover.get_components_by_kind(
|
||||
'LLMAPIRequester'
|
||||
)
|
||||
|
||||
# forge requester class dict
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
requester_dict[component.metadata.name] = (
|
||||
component.get_python_component_class()
|
||||
)
|
||||
|
||||
self.requester_dict = requester_dict
|
||||
|
||||
@@ -74,18 +74,22 @@ class ModelManager:
|
||||
# load models
|
||||
for llm_model in llm_models:
|
||||
await self.load_llm_model(llm_model)
|
||||
|
||||
async def load_llm_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
|
||||
|
||||
async def load_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel
|
||||
| sqlalchemy.Row[persistence_model.LLMModel]
|
||||
| dict,
|
||||
):
|
||||
"""加载模型"""
|
||||
|
||||
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.LLMModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
ap=self.ap, config=model_info.requester_config
|
||||
)
|
||||
|
||||
await requester_inst.initialize()
|
||||
@@ -96,24 +100,23 @@ class ModelManager:
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
),
|
||||
requester=requester_inst
|
||||
requester=requester_inst,
|
||||
)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
"""通过名称获取模型"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
|
||||
|
||||
raise ValueError(f'无法确定模型 {name} 的信息,请在元数据中配置')
|
||||
|
||||
async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
|
||||
"""通过uuid获取模型"""
|
||||
for model in self.llm_models:
|
||||
if model.model_entity.uuid == uuid:
|
||||
return model
|
||||
raise ValueError(f"model {uuid} not found")
|
||||
raise ValueError(f'model {uuid} not found')
|
||||
|
||||
async def remove_llm_model(self, model_uuid: str):
|
||||
"""移除模型"""
|
||||
@@ -124,10 +127,7 @@ class ModelManager:
|
||||
|
||||
def get_available_requesters_info(self) -> list[dict]:
|
||||
"""获取所有可用的请求器"""
|
||||
return [
|
||||
component.to_plain_dict()
|
||||
for component in self.requester_components
|
||||
]
|
||||
return [component.to_plain_dict() for component in self.requester_components]
|
||||
|
||||
def get_available_requester_info_by_name(self, name: str) -> dict | None:
|
||||
"""通过名称获取请求器信息"""
|
||||
@@ -135,8 +135,10 @@ class ModelManager:
|
||||
if component.metadata.name == name:
|
||||
return component.to_plain_dict()
|
||||
return None
|
||||
|
||||
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
|
||||
|
||||
def get_available_requester_manifest_by_name(
|
||||
self, name: str
|
||||
) -> engine.Component | None:
|
||||
"""通过名称获取请求器清单"""
|
||||
for component in self.requester_components:
|
||||
if component.metadata.name == name:
|
||||
@@ -151,4 +153,3 @@ class ModelManager:
|
||||
funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
pass
|
||||
|
||||
|
||||
@@ -22,16 +22,21 @@ class RuntimeLLMModel:
|
||||
|
||||
requester: LLMAPIRequester
|
||||
"""请求器实例"""
|
||||
|
||||
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_entity: persistence_model.LLMModel,
|
||||
token_mgr: token.TokenManager,
|
||||
requester: LLMAPIRequester,
|
||||
):
|
||||
self.model_entity = model_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
"""LLM API请求器"""
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
@@ -42,9 +47,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, ap: app.Application, config: dict[str, typing.Any]):
|
||||
self.ap = ap
|
||||
self.requester_cfg = {
|
||||
**self.default_config
|
||||
}
|
||||
self.requester_cfg = {**self.default_config}
|
||||
self.requester_cfg.update(config)
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -2,16 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import traceback
|
||||
import base64
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
|
||||
from ....core import app
|
||||
from .. import entities, errors, requester
|
||||
from .. import errors, requester
|
||||
|
||||
from .. import entities, errors
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
@@ -29,7 +25,6 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
# cast to a valid type because mypy doesn't understand our type narrowing
|
||||
@@ -40,7 +35,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
)
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key="",
|
||||
api_key='',
|
||||
http_client=httpx_client,
|
||||
)
|
||||
|
||||
@@ -55,7 +50,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args["model"] = model.model_entity.name
|
||||
args['model'] = model.model_entity.name
|
||||
|
||||
# 处理消息
|
||||
|
||||
@@ -63,14 +58,15 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
system_role_message = None
|
||||
|
||||
for i, m in enumerate(messages):
|
||||
if m.role == "system":
|
||||
if m.role == 'system':
|
||||
system_role_message = m
|
||||
|
||||
messages.pop(i)
|
||||
break
|
||||
|
||||
if isinstance(system_role_message, llm_entities.Message) \
|
||||
and isinstance(system_role_message.content, str):
|
||||
if isinstance(system_role_message, llm_entities.Message) and isinstance(
|
||||
system_role_message.content, str
|
||||
):
|
||||
args['system'] = system_role_message.content
|
||||
|
||||
req_messages = []
|
||||
@@ -79,67 +75,64 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
if m.role == 'tool':
|
||||
tool_call_id = m.tool_call_id
|
||||
|
||||
req_messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": m.content
|
||||
}
|
||||
]
|
||||
})
|
||||
req_messages.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{
|
||||
'type': 'tool_result',
|
||||
'tool_use_id': tool_call_id,
|
||||
'content': m.content,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
|
||||
if isinstance(m.content, str) and m.content.strip() != "":
|
||||
msg_dict["content"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": m.content
|
||||
}
|
||||
]
|
||||
if isinstance(m.content, str) and m.content.strip() != '':
|
||||
msg_dict['content'] = [{'type': 'text', 'text': m.content}]
|
||||
elif isinstance(m.content, list):
|
||||
|
||||
for i, ce in enumerate(m.content):
|
||||
|
||||
if ce.type == "image_base64":
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
if ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(
|
||||
ce.image_base64
|
||||
)
|
||||
|
||||
alter_image_ele = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": f"image/{image_format}",
|
||||
"data": image_b64
|
||||
}
|
||||
'type': 'image',
|
||||
'source': {
|
||||
'type': 'base64',
|
||||
'media_type': f'image/{image_format}',
|
||||
'data': image_b64,
|
||||
},
|
||||
}
|
||||
msg_dict["content"][i] = alter_image_ele
|
||||
msg_dict['content'][i] = alter_image_ele
|
||||
|
||||
if m.tool_calls:
|
||||
|
||||
for tool_call in m.tool_calls:
|
||||
msg_dict["content"].append({
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": json.loads(tool_call.function.arguments)
|
||||
})
|
||||
msg_dict['content'].append(
|
||||
{
|
||||
'type': 'tool_use',
|
||||
'id': tool_call.id,
|
||||
'name': tool_call.function.name,
|
||||
'input': json.loads(tool_call.function.arguments),
|
||||
}
|
||||
)
|
||||
|
||||
del msg_dict["tool_calls"]
|
||||
del msg_dict['tool_calls']
|
||||
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
|
||||
args["messages"] = req_messages
|
||||
|
||||
args['messages'] = req_messages
|
||||
|
||||
if funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
try:
|
||||
# print(json.dumps(args, indent=4, ensure_ascii=False))
|
||||
@@ -149,23 +142,24 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
'content': '',
|
||||
'role': resp.role,
|
||||
}
|
||||
|
||||
|
||||
assert type(resp) is anthropic.types.message.Message
|
||||
|
||||
for block in resp.content:
|
||||
if block.type == 'thinking':
|
||||
args['content'] = '<think>' + block.thinking + '</think>\n' + args['content']
|
||||
args['content'] = (
|
||||
'<think>' + block.thinking + '</think>\n' + args['content']
|
||||
)
|
||||
elif block.type == 'text':
|
||||
args['content'] += block.text
|
||||
elif block.type == 'tool_use':
|
||||
assert type(block) is anthropic.types.tool_use_block.ToolUseBlock
|
||||
tool_call = llm_entities.ToolCall(
|
||||
id=block.id,
|
||||
type="function",
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=block.name,
|
||||
arguments=json.dumps(block.input)
|
||||
)
|
||||
name=block.name, arguments=json.dumps(block.input)
|
||||
),
|
||||
)
|
||||
if 'tool_calls' not in args:
|
||||
args['tool_calls'] = []
|
||||
|
||||
@@ -4,8 +4,6 @@ import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
from ....core import app
|
||||
|
||||
|
||||
class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -2,22 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call
|
||||
import httpx
|
||||
import aiohttp
|
||||
import async_lru
|
||||
|
||||
from .. import entities, errors, requester
|
||||
from ....core import entities as core_entities, app
|
||||
from .. import errors, requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....utils import image
|
||||
|
||||
|
||||
class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
@@ -26,18 +19,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.requester_cfg["base_url"],
|
||||
timeout=self.requester_cfg["timeout"],
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(
|
||||
trust_env=True, timeout=self.requester_cfg["timeout"]
|
||||
trust_env=True, timeout=self.requester_cfg['timeout']
|
||||
),
|
||||
)
|
||||
|
||||
@@ -54,8 +46,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
|
||||
chatcmpl_message["role"] = "assistant"
|
||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
||||
chatcmpl_message['role'] = 'assistant'
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
@@ -72,27 +64,27 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "image_base64":
|
||||
me["image_url"] = {"url": me["image_base64"]}
|
||||
me["type"] = "image_url"
|
||||
del me["image_base64"]
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
|
||||
args["messages"] = messages
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
@@ -113,15 +105,15 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
content = msg_dict.get("content")
|
||||
content = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(
|
||||
isinstance(part, dict) and part.get("type") == "text"
|
||||
isinstance(part, dict) and part.get('type') == 'text'
|
||||
for part in content
|
||||
):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
@@ -133,17 +125,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
extra_args=extra_args,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError("请求超时")
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
if "context_length_exceeded" in e.message:
|
||||
raise errors.RequesterError(f"上文过长,请重置会话: {e.message}")
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f"请求参数错误: {e.message}")
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f"无效的 api-key: {e.message}")
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
except openai.NotFoundError as e:
|
||||
raise errors.RequesterError(f"请求路径错误: {e.message}")
|
||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
||||
except openai.RateLimitError as e:
|
||||
raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}")
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f"请求错误: {e.message}")
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import entities, errors, requester
|
||||
from ....core import entities as core_entities, app
|
||||
from .. import errors, requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
@@ -28,23 +28,23 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if 'content' in m and isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
if 'content' in m and isinstance(m['content'], list):
|
||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
||||
|
||||
args["messages"] = messages
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
@@ -55,4 +55,4 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
return message
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import typing
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import entities, errors, requester
|
||||
from ....core import app, entities as core_entities
|
||||
from .. import requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from .. import entities as modelmgr_entities
|
||||
|
||||
|
||||
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
@@ -33,20 +29,20 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
# gitee 不支持多模态,把content都转换成纯文字
|
||||
for m in req_messages:
|
||||
if 'content' in m and isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
if 'content' in m and isinstance(m['content'], list):
|
||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
||||
|
||||
args["messages"] = req_messages
|
||||
args['messages'] = req_messages
|
||||
|
||||
resp = await self._req(args)
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
from ....core import app
|
||||
|
||||
|
||||
class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -2,11 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import entities, errors, requester
|
||||
from ....core import entities as core_entities, app
|
||||
from .. import requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
@@ -30,26 +29,26 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if 'content' in m and isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
if 'content' in m and isinstance(m['content'], list):
|
||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
||||
|
||||
# 删除空的
|
||||
messages = [m for m in messages if m["content"].strip() != ""]
|
||||
messages = [m for m in messages if m['content'].strip() != '']
|
||||
|
||||
args["messages"] = messages
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
@@ -57,4 +56,4 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
return message
|
||||
|
||||
@@ -6,18 +6,15 @@ import typing
|
||||
from typing import Union, Mapping, Any, AsyncIterator
|
||||
import uuid
|
||||
import json
|
||||
import base64
|
||||
|
||||
import async_lru
|
||||
import ollama
|
||||
|
||||
from .. import entities, errors, requester
|
||||
from .. import errors, requester
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....core import app, entities as core_entities
|
||||
from ....utils import image
|
||||
from ....core import entities as core_entities
|
||||
|
||||
REQUESTER_NAME: str = "ollama-chat"
|
||||
REQUESTER_NAME: str = 'ollama-chat'
|
||||
|
||||
|
||||
class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
@@ -26,13 +23,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
client: ollama.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
"base_url": "http://127.0.0.1:11434",
|
||||
"timeout": 120,
|
||||
'base_url': 'http://127.0.0.1:11434',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
|
||||
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
|
||||
os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url']
|
||||
self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout'])
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
@@ -49,35 +46,35 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
messages: list[dict] = req_messages.copy()
|
||||
for msg in messages:
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "text":
|
||||
text_content.append(me["text"])
|
||||
elif me["type"] == "image_base64":
|
||||
image_urls.append(me["image_base64"])
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'text':
|
||||
text_content.append(me['text'])
|
||||
elif me['type'] == 'image_base64':
|
||||
image_urls.append(me['image_base64'])
|
||||
|
||||
msg["content"] = "\n".join(text_content)
|
||||
msg["images"] = [url.split(",")[1] for url in image_urls]
|
||||
msg['content'] = '\n'.join(text_content)
|
||||
msg['images'] = [url.split(',')[1] for url in image_urls]
|
||||
if (
|
||||
"tool_calls" in msg
|
||||
'tool_calls' in msg
|
||||
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_call["function"]["arguments"] = json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
for tool_call in msg['tool_calls']:
|
||||
tool_call['function']['arguments'] = json.loads(
|
||||
tool_call['function']['arguments']
|
||||
)
|
||||
args["messages"] = messages
|
||||
args['messages'] = messages
|
||||
|
||||
args["tools"] = []
|
||||
args['tools'] = []
|
||||
if user_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs)
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
resp = await self._req(args)
|
||||
message: llm_entities.Message = await self._make_msg(resp)
|
||||
@@ -93,7 +90,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
ret_msg: llm_entities.Message = None
|
||||
|
||||
if message.content is not None:
|
||||
ret_msg = llm_entities.Message(role="assistant", content=message.content)
|
||||
ret_msg = llm_entities.Message(role='assistant', content=message.content)
|
||||
if message.tool_calls is not None and len(message.tool_calls) > 0:
|
||||
tool_calls: list[llm_entities.ToolCall] = []
|
||||
|
||||
@@ -101,7 +98,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
tool_calls.append(
|
||||
llm_entities.ToolCall(
|
||||
id=uuid.uuid4().hex,
|
||||
type="function",
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=json.dumps(tool_call.function.arguments),
|
||||
@@ -123,13 +120,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get("content")
|
||||
content: Any = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
if all(
|
||||
isinstance(part, dict) and part.get("type") == "text"
|
||||
isinstance(part, dict) and part.get('type') == 'text'
|
||||
for part in content
|
||||
):
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
return await self._closure(
|
||||
@@ -140,4 +137,4 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
extra_args=extra_args,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError("请求超时")
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
@@ -4,8 +4,6 @@ import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
from ....core import app
|
||||
|
||||
|
||||
class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -4,8 +4,6 @@ import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
from ....core import app
|
||||
|
||||
|
||||
class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -4,8 +4,6 @@ import typing
|
||||
import openai
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
from ....core import app
|
||||
|
||||
|
||||
class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -3,9 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import openai
|
||||
|
||||
from ....core import app
|
||||
from . import chatcmpl
|
||||
from .. import requester
|
||||
|
||||
|
||||
class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
@@ -3,9 +3,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
|
||||
class TokenManager():
|
||||
"""鉴权 Token 管理器
|
||||
"""
|
||||
class TokenManager:
|
||||
"""鉴权 Token 管理器"""
|
||||
|
||||
name: str
|
||||
|
||||
@@ -20,6 +19,6 @@ class TokenManager():
|
||||
|
||||
def get_token(self) -> str:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
|
||||
def next_token(self):
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||
|
||||
@@ -9,9 +9,10 @@ from . import entities as llm_entities
|
||||
|
||||
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||
|
||||
|
||||
def runner_class(name: str):
|
||||
"""注册一个请求运行器
|
||||
"""
|
||||
"""注册一个请求运行器"""
|
||||
|
||||
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
|
||||
cls.name = name
|
||||
preregistered_runners.append(cls)
|
||||
@@ -21,8 +22,8 @@ def runner_class(name: str):
|
||||
|
||||
|
||||
class RequestRunner(abc.ABC):
|
||||
"""请求运行器
|
||||
"""
|
||||
"""请求运行器"""
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
@@ -34,7 +35,8 @@ class RequestRunner(abc.ABC):
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
import re
|
||||
|
||||
import dashscope
|
||||
@@ -10,7 +8,7 @@ import dashscope
|
||||
from .. import runner
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ...utils import image
|
||||
|
||||
|
||||
class DashscopeAPIError(Exception):
|
||||
"""Dashscope API 请求失败"""
|
||||
@@ -20,49 +18,49 @@ class DashscopeAPIError(Exception):
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class("dashscope-app-api")
|
||||
@runner.runner_class('dashscope-app-api')
|
||||
class DashScopeAPIRunner(runner.RequestRunner):
|
||||
"阿里云百炼DashsscopeAPI对话请求器"
|
||||
|
||||
|
||||
# 运行器内部使用的配置
|
||||
app_type: str # 应用类型
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
app_type: str # 应用类型
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ["agent", "workflow"]
|
||||
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
|
||||
#检查配置文件中使用的应用类型是否支持
|
||||
if (self.app_type not in valid_app_types):
|
||||
raise DashscopeAPIError(
|
||||
f"不支持的 Dashscope 应用类型: {self.app_type}"
|
||||
)
|
||||
|
||||
#初始化Dashscope 参数配置
|
||||
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
|
||||
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
|
||||
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
|
||||
|
||||
valid_app_types = ['agent', 'workflow']
|
||||
self.app_type = self.pipeline_config['ai']['dashscope-app-api']['app-type']
|
||||
# 检查配置文件中使用的应用类型是否支持
|
||||
if self.app_type not in valid_app_types:
|
||||
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
|
||||
|
||||
# 初始化Dashscope 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
|
||||
self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][
|
||||
'references_quote'
|
||||
]
|
||||
|
||||
def _replace_references(self, text, references_dict):
|
||||
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
|
||||
|
||||
|
||||
# 匹配 <ref>[index_id]</ref> 形式的字符串
|
||||
pattern = re.compile(r'<ref>\[(.*?)\]</ref>')
|
||||
|
||||
def replacement(match):
|
||||
# 获取引用编号
|
||||
ref_key = match.group(1)
|
||||
ref_key = match.group(1)
|
||||
if ref_key in references_dict:
|
||||
# 如果有对应的参考资料按照provider.json中的reference_quote返回提示,来自哪个参考资料文件
|
||||
return f"({self.references_quote} {references_dict[ref_key]})"
|
||||
return f'({self.references_quote} {references_dict[ref_key]})'
|
||||
else:
|
||||
# 如果没有对应的参考资料,保留原样
|
||||
return match.group(0)
|
||||
return match.group(0)
|
||||
|
||||
# 使用 re.sub() 进行替换
|
||||
return pattern.sub(replacement, text)
|
||||
@@ -71,14 +69,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
self, query: core_entities.Query
|
||||
) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
|
||||
plain_text = ""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == "text":
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
# 暂时不支持上传图片,保留代码以便后续扩展
|
||||
# elif ce.type == "image_base64":
|
||||
# elif ce.type == "image_base64":
|
||||
# image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
# file_bytes = base64.b64decode(image_b64)
|
||||
# file = ("img.png", file_bytes, f"image/{image_format}")
|
||||
@@ -92,147 +90,141 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""Dashscope 智能体对话请求"""
|
||||
|
||||
#局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = "" # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = "" # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
|
||||
# 局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = '' # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = '' # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
#发送对话请求
|
||||
|
||||
# 发送对话请求
|
||||
response = dashscope.Application.call(
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if chunk.get("status_code") != 200:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
#获取流式传输的output
|
||||
stream_output = chunk.get("output", {})
|
||||
if stream_output.get("text") is not None:
|
||||
pending_content += stream_output.get("text")
|
||||
|
||||
#保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get("session_id")
|
||||
|
||||
#获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get("doc_references", [])
|
||||
|
||||
#从模型传出的参考资料信息中提取用于替换的字典
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get("index_id") is not None:
|
||||
references_dict[doc.get("index_id")] = doc.get("doc_name")
|
||||
|
||||
#将参考资料替换到文本中
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
|
||||
yield llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""Dashscope 工作流对话请求"""
|
||||
|
||||
#局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = "" # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = "" # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
|
||||
# 局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = '' # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = '' # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
biz_params = {}
|
||||
biz_params.update(query.variables)
|
||||
|
||||
#发送对话请求
|
||||
|
||||
# 发送对话请求
|
||||
response = dashscope.Application.call(
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
biz_params=biz_params, # 工作流应用的自定义输入参数传递
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
biz_params=biz_params, # 工作流应用的自定义输入参数传递
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
|
||||
#处理API返回的流式输出
|
||||
|
||||
# 处理API返回的流式输出
|
||||
for chunk in response:
|
||||
if chunk.get("status_code") != 200:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
#获取流式传输的output
|
||||
stream_output = chunk.get("output", {})
|
||||
if stream_output.get("text") is not None:
|
||||
pending_content += stream_output.get("text")
|
||||
|
||||
#保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get("session_id")
|
||||
|
||||
#获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get("doc_references", [])
|
||||
|
||||
#从模型传出的参考资料信息中提取用于替换的字典
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get("index_id") is not None:
|
||||
references_dict[doc.get("index_id")] = doc.get("doc_name")
|
||||
|
||||
#将参考资料替换到文本中
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
|
||||
yield llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行"""
|
||||
if self.app_type == "agent":
|
||||
if self.app_type == 'agent':
|
||||
async for msg in self._agent_messages(query):
|
||||
yield msg
|
||||
elif self.app_type == "workflow":
|
||||
elif self.app_type == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise DashscopeAPIError(
|
||||
f"不支持的 Dashscope 应用类型: {self.app_type}"
|
||||
)
|
||||
|
||||
|
||||
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
|
||||
|
||||
@@ -5,9 +5,7 @@ import json
|
||||
import uuid
|
||||
import re
|
||||
import base64
|
||||
import datetime
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .. import runner
|
||||
from ...core import app, entities as core_entities
|
||||
@@ -17,7 +15,7 @@ from ...utils import image
|
||||
from libs.dify_service_api.v1 import client, errors
|
||||
|
||||
|
||||
@runner.runner_class("dify-service-api")
|
||||
@runner.runner_class('dify-service-api')
|
||||
class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
"""Dify Service API 对话请求器"""
|
||||
|
||||
@@ -27,38 +25,54 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ["chat", "agent", "workflow"]
|
||||
valid_app_types = ['chat', 'agent', 'workflow']
|
||||
if (
|
||||
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
|
||||
self.pipeline_config['ai']['dify-service-api']['app-type']
|
||||
not in valid_app_types
|
||||
):
|
||||
raise errors.DifyAPIError(
|
||||
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
|
||||
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
|
||||
api_key = self.pipeline_config['ai']['dify-service-api']['api-key']
|
||||
|
||||
self.dify_client = client.AsyncDifyServiceClient(
|
||||
api_key=api_key,
|
||||
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
|
||||
base_url=self.pipeline_config['ai']['dify-service-api']['base-url'],
|
||||
)
|
||||
|
||||
def _try_convert_thinking(self, resp_text: str) -> str:
|
||||
"""尝试转换 Dify 的思考提示"""
|
||||
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
|
||||
if not resp_text.startswith(
|
||||
'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>'
|
||||
):
|
||||
return resp_text
|
||||
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'original'
|
||||
):
|
||||
return resp_text
|
||||
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
|
||||
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
|
||||
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
|
||||
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'remove'
|
||||
):
|
||||
return re.sub(
|
||||
r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>',
|
||||
'',
|
||||
resp_text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'plain'
|
||||
):
|
||||
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
|
||||
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
|
||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||
return f"<think>{thinking_text.group(1)}</think>\n{content_text}"
|
||||
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
|
||||
|
||||
async def _preprocess_user_message(
|
||||
self, query: core_entities.Query
|
||||
@@ -68,22 +82,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
Returns:
|
||||
tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID
|
||||
"""
|
||||
plain_text = ""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == "text":
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == "image_base64":
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(
|
||||
ce.image_base64
|
||||
)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
file = ("img.png", file_bytes, f"image/{image_format}")
|
||||
file = ('img.png', file_bytes, f'image/{image_format}')
|
||||
file_upload_resp = await self.dify_client.upload_file(
|
||||
file,
|
||||
f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
)
|
||||
image_id = file_upload_resp["id"]
|
||||
image_id = file_upload_resp['id']
|
||||
image_ids.append(image_id)
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
@@ -94,116 +110,119 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or ""
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": image_id,
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
mode = "basic" # 标记是基础编排还是工作流编排
|
||||
mode = 'basic' # 标记是基础编排还是工作流编排
|
||||
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
|
||||
):
|
||||
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_started':
|
||||
mode = "workflow"
|
||||
mode = 'workflow'
|
||||
|
||||
if mode == "workflow":
|
||||
if mode == 'workflow':
|
||||
if chunk['event'] == 'node_finished':
|
||||
if chunk['data']['node_type'] == 'answer':
|
||||
yield llm_entities.Message(
|
||||
role="assistant",
|
||||
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(
|
||||
chunk['data']['outputs']['answer']
|
||||
),
|
||||
)
|
||||
elif mode == "basic":
|
||||
elif mode == 'basic':
|
||||
if chunk['event'] == 'message':
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
elif chunk['event'] == 'message_end':
|
||||
yield llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(basic_mode_pending_chunk),
|
||||
)
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
query.session.using_conversation.uuid = chunk["conversation_id"]
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _agent_chat_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or ""
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": image_id,
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
ignored_events = ["agent_message"]
|
||||
ignored_events = ['agent_message']
|
||||
|
||||
inputs = {}
|
||||
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
response_mode="streaming",
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
response_mode='streaming',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
|
||||
):
|
||||
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
|
||||
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
|
||||
|
||||
if chunk["event"] in ignored_events:
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
if chunk["event"] == "agent_thought":
|
||||
|
||||
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
|
||||
if chunk['event'] == 'agent_thought':
|
||||
if (
|
||||
chunk['tool'] != '' and chunk['observation'] != ''
|
||||
): # 工具调用结果,跳过
|
||||
continue
|
||||
|
||||
if chunk['thought'].strip() != '': # 文字回复内容
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=chunk["thought"],
|
||||
role='assistant',
|
||||
content=chunk['thought'],
|
||||
)
|
||||
yield msg
|
||||
|
||||
if chunk['tool']:
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
llm_entities.ToolCall(
|
||||
id=chunk['id'],
|
||||
type="function",
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=chunk["tool"],
|
||||
name=chunk['tool'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
@@ -211,9 +230,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
yield msg
|
||||
if chunk['event'] == 'message_file':
|
||||
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
@@ -222,11 +239,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=[llm_entities.ContentElement.from_image_url(image_url)],
|
||||
)
|
||||
|
||||
query.session.using_conversation.uuid = chunk["conversation_id"]
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: core_entities.Query
|
||||
@@ -235,58 +252,57 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
query.variables["conversation_id"] = query.session.using_conversation.uuid
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": image_id,
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
ignored_events = ["text_chunk", "workflow_started"]
|
||||
ignored_events = ['text_chunk', 'workflow_started']
|
||||
|
||||
inputs = { # these variables are legacy variables, we need to keep them for compatibility
|
||||
"langbot_user_message_text": plain_text,
|
||||
"langbot_session_id": query.variables["session_id"],
|
||||
"langbot_conversation_id": query.variables["conversation_id"],
|
||||
"langbot_msg_create_time": query.variables["msg_create_time"],
|
||||
'langbot_user_message_text': plain_text,
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
}
|
||||
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
files=files,
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
|
||||
):
|
||||
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
|
||||
if chunk["event"] in ignored_events:
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk["event"] == "node_started":
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if (
|
||||
chunk["data"]["node_type"] == "start"
|
||||
or chunk["data"]["node_type"] == "end"
|
||||
chunk['data']['node_type'] == 'start'
|
||||
or chunk['data']['node_type'] == 'end'
|
||||
):
|
||||
continue
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
role='assistant',
|
||||
content=None,
|
||||
tool_calls=[
|
||||
llm_entities.ToolCall(
|
||||
id=chunk["data"]["node_id"],
|
||||
type="function",
|
||||
id=chunk['data']['node_id'],
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=chunk["data"]["title"],
|
||||
name=chunk['data']['title'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
@@ -295,13 +311,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
yield msg
|
||||
|
||||
elif chunk["event"] == "workflow_finished":
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=chunk["data"]["outputs"]["summary"],
|
||||
role='assistant',
|
||||
content=chunk['data']['outputs']['summary'],
|
||||
)
|
||||
|
||||
yield msg
|
||||
@@ -310,16 +326,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
|
||||
@@ -4,24 +4,28 @@ import json
|
||||
import typing
|
||||
|
||||
from .. import runner
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
|
||||
|
||||
@runner.runner_class("local-agent")
|
||||
@runner.runner_class('local-agent')
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""本地Agent请求运行器
|
||||
"""
|
||||
"""本地Agent请求运行器"""
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
req_messages = (
|
||||
query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
)
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query, query.use_llm_model, req_messages, query.use_funcs
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
@@ -34,7 +38,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
@@ -42,7 +46,9 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
|
||||
yield msg
|
||||
@@ -51,7 +57,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
role='tool', content=f'err: {e}', tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
@@ -59,7 +65,9 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query, query.use_llm_model, req_messages, query.use_funcs
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
|
||||
@@ -3,13 +3,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...plugin import context as plugin_context
|
||||
from ...provider import entities as provider_entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器
|
||||
"""
|
||||
"""会话管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -23,10 +21,12 @@ class SessionManager:
|
||||
pass
|
||||
|
||||
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
|
||||
"""获取会话
|
||||
"""
|
||||
"""获取会话"""
|
||||
for session in self.session_list:
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
if (
|
||||
query.launcher_type == session.launcher_type
|
||||
and query.launcher_id == session.launcher_id
|
||||
):
|
||||
return session
|
||||
|
||||
session_concurrency = self.ap.instance_config.data['concurrency']['session']
|
||||
@@ -39,7 +39,12 @@ class SessionManager:
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation:
|
||||
async def get_conversation(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
session: core_entities.Session,
|
||||
prompt_config: list[dict],
|
||||
) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
@@ -52,7 +57,7 @@ class SessionManager:
|
||||
prompt_messages.append(provider_entities.Message(**prompt_message))
|
||||
|
||||
prompt = provider_entities.Prompt(
|
||||
name="default",
|
||||
name='default',
|
||||
messages=prompt_messages,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
class LLMFunction(pydantic.BaseModel):
|
||||
"""函数"""
|
||||
|
||||
@@ -9,9 +9,10 @@ from . import entities as tools_entities
|
||||
|
||||
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
||||
|
||||
|
||||
def loader_class(name: str):
|
||||
"""注册一个工具加载器
|
||||
"""
|
||||
"""注册一个工具加载器"""
|
||||
|
||||
def decorator(cls: typing.Type[ToolLoader]) -> typing.Type[ToolLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
@@ -22,7 +23,7 @@ def loader_class(name: str):
|
||||
|
||||
class ToolLoader(abc.ABC):
|
||||
"""工具加载器"""
|
||||
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
@@ -34,7 +35,7 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
"""获取所有工具"""
|
||||
pass
|
||||
|
||||
@@ -44,11 +45,13 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -30,7 +30,7 @@ class RuntimeMCPSession:
|
||||
self.server_name = server_name
|
||||
self.server_config = server_config
|
||||
self.ap = ap
|
||||
|
||||
|
||||
self.session = None
|
||||
|
||||
self.exit_stack = AsyncExitStack()
|
||||
@@ -38,9 +38,9 @@ class RuntimeMCPSession:
|
||||
|
||||
async def _init_stdio_python_server(self):
|
||||
server_params = StdioServerParameters(
|
||||
command=self.server_config["command"],
|
||||
args=self.server_config["args"],
|
||||
env=self.server_config["env"],
|
||||
command=self.server_config['command'],
|
||||
args=self.server_config['args'],
|
||||
env=self.server_config['env'],
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
@@ -58,12 +58,12 @@ class RuntimeMCPSession:
|
||||
async def _init_sse_server(self):
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(
|
||||
self.server_config["url"],
|
||||
headers=self.server_config.get("headers", {}),
|
||||
timeout=self.server_config.get("timeout", 10),
|
||||
self.server_config['url'],
|
||||
headers=self.server_config.get('headers', {}),
|
||||
timeout=self.server_config.get('timeout', 10),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
sseio, write = sse_transport
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
@@ -73,18 +73,22 @@ class RuntimeMCPSession:
|
||||
await self.session.initialize()
|
||||
|
||||
async def initialize(self):
|
||||
self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")
|
||||
self.ap.logger.debug(
|
||||
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
|
||||
)
|
||||
|
||||
if self.server_config["mode"] == "stdio":
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config["mode"] == "sse":
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
else:
|
||||
raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")
|
||||
|
||||
raise ValueError(
|
||||
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
|
||||
)
|
||||
|
||||
tools = await self.session.list_tools()
|
||||
|
||||
self.ap.logger.debug(f"获取 MCP 工具: {tools}")
|
||||
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
|
||||
|
||||
for tool in tools.tools:
|
||||
|
||||
@@ -93,25 +97,28 @@ class RuntimeMCPSession:
|
||||
if result.isError:
|
||||
raise Exception(result.content[0].text)
|
||||
return result.content[0].text
|
||||
|
||||
|
||||
func.__name__ = tool.name
|
||||
|
||||
self.functions.append(tools_entities.LLMFunction(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
))
|
||||
self.functions.append(
|
||||
tools_entities.LLMFunction(
|
||||
name=tool.name,
|
||||
human_desc=tool.description,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
await self.session._exit_stack.aclose()
|
||||
|
||||
@loader.loader_class("mcp")
|
||||
|
||||
@loader.loader_class('mcp')
|
||||
class MCPLoader(loader.ToolLoader):
|
||||
"""MCP 工具加载器。
|
||||
|
||||
|
||||
在此加载器中管理所有与 MCP Server 的连接。
|
||||
"""
|
||||
|
||||
@@ -125,16 +132,17 @@ class MCPLoader(loader.ToolLoader):
|
||||
self._last_listed_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for server_config in self.ap.instance_config.data.get("mcp", {}).get("servers", []):
|
||||
if not server_config["enable"]:
|
||||
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
|
||||
'servers', []
|
||||
):
|
||||
if not server_config['enable']:
|
||||
continue
|
||||
session = RuntimeMCPSession(server_config["name"], server_config, self.ap)
|
||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
||||
await session.initialize()
|
||||
# self.ap.event_loop.create_task(session.initialize())
|
||||
self.sessions[server_config["name"]] = session
|
||||
self.sessions[server_config['name']] = session
|
||||
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
all_functions = []
|
||||
|
||||
for session in self.sessions.values():
|
||||
@@ -147,13 +155,15 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
return name in [f.name for f in self._last_listed_functions]
|
||||
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
for server_name, session in self.sessions.items():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
return await function.func(query, **parameters)
|
||||
|
||||
raise ValueError(f"未找到工具: {name}")
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭工具"""
|
||||
|
||||
@@ -4,19 +4,18 @@ import typing
|
||||
import traceback
|
||||
|
||||
from .. import loader, entities as tools_entities
|
||||
from ....core import app, entities as core_entities
|
||||
from ....core import entities as core_entities
|
||||
from ....plugin import context as plugin_context
|
||||
|
||||
|
||||
@loader.loader_class("plugin-tool-loader")
|
||||
@loader.loader_class('plugin-tool-loader')
|
||||
class PluginToolLoader(loader.ToolLoader):
|
||||
"""插件工具加载器。
|
||||
|
||||
|
||||
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
|
||||
"""
|
||||
|
||||
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
|
||||
|
||||
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
|
||||
# 从插件系统获取工具(内容函数)
|
||||
all_functions: list[tools_entities.LLMFunction] = []
|
||||
|
||||
@@ -49,23 +48,23 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
return function, plugin.plugin_inst
|
||||
return None, None
|
||||
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
try:
|
||||
|
||||
function, plugin = await self._get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
parameters = parameters.copy()
|
||||
|
||||
parameters = {"query": query, **parameters}
|
||||
parameters = {'query': query, **parameters}
|
||||
|
||||
return await function.func(plugin, **parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
return f"error occurred when executing function {name}: {e}"
|
||||
return f'error occurred when executing function {name}: {e}'
|
||||
finally:
|
||||
plugin = None
|
||||
|
||||
@@ -75,13 +74,12 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
break
|
||||
|
||||
if plugin is not None:
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_function_record(
|
||||
plugin={
|
||||
"name": plugin.plugin_name,
|
||||
"remote": plugin.plugin_repository,
|
||||
"version": plugin.plugin_version,
|
||||
"author": plugin.plugin_author,
|
||||
'name': plugin.plugin_name,
|
||||
'remote': plugin.plugin_repository,
|
||||
'version': plugin.plugin_version,
|
||||
'author': plugin.plugin_author,
|
||||
},
|
||||
function_name=function.name,
|
||||
function_description=function.description,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities, loader as tools_loader
|
||||
from ...plugin import context as plugin_context
|
||||
from .loaders import plugin, mcp
|
||||
from ...utils import importutil
|
||||
from . import loaders
|
||||
|
||||
importutil.import_modules_in_pkg(loaders)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
@@ -22,13 +23,14 @@ class ToolManager:
|
||||
self.loaders = []
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for loader_cls in tools_loader.preregistered_loaders:
|
||||
loader_inst = loader_cls(self.ap)
|
||||
await loader_inst.initialize()
|
||||
self.loaders.append(loader_inst)
|
||||
|
||||
async def get_all_functions(self, plugin_enabled: bool=None) -> list[entities.LLMFunction]:
|
||||
async def get_all_functions(
|
||||
self, plugin_enabled: bool = None
|
||||
) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[entities.LLMFunction] = []
|
||||
|
||||
@@ -37,17 +39,19 @@ class ToolManager:
|
||||
|
||||
return all_functions
|
||||
|
||||
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
|
||||
async def generate_tools_for_openai(
|
||||
self, use_funcs: list[entities.LLMFunction]
|
||||
) -> list:
|
||||
"""生成函数列表"""
|
||||
tools = []
|
||||
|
||||
for function in use_funcs:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': function.name,
|
||||
'description': function.description,
|
||||
'parameters': function.parameters,
|
||||
},
|
||||
}
|
||||
tools.append(function_schema)
|
||||
@@ -83,9 +87,9 @@ class ToolManager:
|
||||
|
||||
for function in use_funcs:
|
||||
function_schema = {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"input_schema": function.parameters,
|
||||
'name': function.name,
|
||||
'description': function.description,
|
||||
'input_schema': function.parameters,
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
@@ -100,7 +104,7 @@ class ToolManager:
|
||||
if await loader.has_tool(name):
|
||||
return await loader.invoke_tool(query, name, parameters)
|
||||
else:
|
||||
raise ValueError(f"未找到工具: {name}")
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭所有工具"""
|
||||
|
||||
Reference in New Issue
Block a user