style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions
+110 -118
View File
@@ -1,8 +1,6 @@
from __future__ import annotations
import typing
import json
import base64
import re
import dashscope
@@ -10,7 +8,7 @@ import dashscope
from .. import runner
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
class DashscopeAPIError(Exception):
"""Dashscope API 请求失败"""
@@ -20,49 +18,49 @@ class DashscopeAPIError(Exception):
super().__init__(self.message)
@runner.runner_class("dashscope-app-api")
@runner.runner_class('dashscope-app-api')
class DashScopeAPIRunner(runner.RequestRunner):
"阿里云百炼DashsscopeAPI对话请求器"
# 运行器内部使用的配置
app_type: str # 应用类型
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
app_type: str # 应用类型
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["agent", "workflow"]
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
#检查配置文件中使用的应用类型是否支持
if (self.app_type not in valid_app_types):
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.app_type}"
)
#初始化Dashscope 参数配置
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
valid_app_types = ['agent', 'workflow']
self.app_type = self.pipeline_config['ai']['dashscope-app-api']['app-type']
# 检查配置文件中使用的应用类型是否支持
if self.app_type not in valid_app_types:
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
# 初始化Dashscope 参数配置
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][
'references_quote'
]
def _replace_references(self, text, references_dict):
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
# 匹配 <ref>[index_id]</ref> 形式的字符串
pattern = re.compile(r'<ref>\[(.*?)\]</ref>')
def replacement(match):
# 获取引用编号
ref_key = match.group(1)
ref_key = match.group(1)
if ref_key in references_dict:
# 如果有对应的参考资料按照provider.json中的reference_quote返回提示,来自哪个参考资料文件
return f"({self.references_quote} {references_dict[ref_key]})"
return f'({self.references_quote} {references_dict[ref_key]})'
else:
# 如果没有对应的参考资料,保留原样
return match.group(0)
return match.group(0)
# 使用 re.sub() 进行替换
return pattern.sub(replacement, text)
@@ -71,14 +69,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
plain_text = ""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == "text":
if ce.type == 'text':
plain_text += ce.text
# 暂时不支持上传图片,保留代码以便后续扩展
# elif ce.type == "image_base64":
# elif ce.type == "image_base64":
# image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
# file_bytes = base64.b64decode(image_b64)
# file = ("img.png", file_bytes, f"image/{image_format}")
@@ -92,147 +90,141 @@ class DashScopeAPIRunner(runner.RequestRunner):
plain_text = query.user_message.content
return plain_text, image_ids
async def _agent_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 智能体对话请求"""
#局部变量
chunk = None # 流式传输的块
pending_content = "" # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = "" # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
# 局部变量
chunk = None # 流式传输的块
pending_content = '' # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = '' # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
plain_text, image_ids = await self._preprocess_user_message(query)
#发送对话请求
# 发送对话请求
response = dashscope.Application.call(
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
# rag_options={ # 主要用于文件交互,暂不支持
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
# }
)
for chunk in response:
if chunk.get("status_code") != 200:
if chunk.get('status_code') != 200:
raise DashscopeAPIError(
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
)
if not chunk:
continue
#获取流式传输的output
stream_output = chunk.get("output", {})
if stream_output.get("text") is not None:
pending_content += stream_output.get("text")
#保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get("session_id")
#获取模型传出的参考资料列表
references_dict_list = stream_output.get("doc_references", [])
#从模型传出的参考资料信息中提取用于替换的字典
# 获取流式传输的output
stream_output = chunk.get('output', {})
if stream_output.get('text') is not None:
pending_content += stream_output.get('text')
# 保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get('session_id')
# 获取模型传出的参考资料列表
references_dict_list = stream_output.get('doc_references', [])
# 从模型传出的参考资料信息中提取用于替换的字典
if references_dict_list is not None:
for doc in references_dict_list:
if doc.get("index_id") is not None:
references_dict[doc.get("index_id")] = doc.get("doc_name")
#将参考资料替换到文本中
if doc.get('index_id') is not None:
references_dict[doc.get('index_id')] = doc.get('doc_name')
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
role="assistant",
role='assistant',
content=pending_content,
)
async def _workflow_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 工作流对话请求"""
#局部变量
chunk = None # 流式传输的块
pending_content = "" # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = "" # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
# 局部变量
chunk = None # 流式传输的块
pending_content = '' # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = '' # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
plain_text, image_ids = await self._preprocess_user_message(query)
biz_params = {}
biz_params.update(query.variables)
#发送对话请求
# 发送对话请求
response = dashscope.Application.call(
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
biz_params=biz_params, # 工作流应用的自定义输入参数传递
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
biz_params=biz_params, # 工作流应用的自定义输入参数传递
# rag_options={ # 主要用于文件交互,暂不支持
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
# }
)
#处理API返回的流式输出
# 处理API返回的流式输出
for chunk in response:
if chunk.get("status_code") != 200:
if chunk.get('status_code') != 200:
raise DashscopeAPIError(
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
)
if not chunk:
continue
#获取流式传输的output
stream_output = chunk.get("output", {})
if stream_output.get("text") is not None:
pending_content += stream_output.get("text")
#保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get("session_id")
#获取模型传出的参考资料列表
references_dict_list = stream_output.get("doc_references", [])
#从模型传出的参考资料信息中提取用于替换的字典
# 获取流式传输的output
stream_output = chunk.get('output', {})
if stream_output.get('text') is not None:
pending_content += stream_output.get('text')
# 保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get('session_id')
# 获取模型传出的参考资料列表
references_dict_list = stream_output.get('doc_references', [])
# 从模型传出的参考资料信息中提取用于替换的字典
if references_dict_list is not None:
for doc in references_dict_list:
if doc.get("index_id") is not None:
references_dict[doc.get("index_id")] = doc.get("doc_name")
#将参考资料替换到文本中
if doc.get('index_id') is not None:
references_dict[doc.get('index_id')] = doc.get('doc_name')
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
role="assistant",
role='assistant',
content=pending_content,
)
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行"""
if self.app_type == "agent":
if self.app_type == 'agent':
async for msg in self._agent_messages(query):
yield msg
elif self.app_type == "workflow":
elif self.app_type == 'workflow':
async for msg in self._workflow_messages(query):
yield msg
else:
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.app_type}"
)
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
+108 -92
View File
@@ -5,9 +5,7 @@ import json
import uuid
import re
import base64
import datetime
import aiohttp
from .. import runner
from ...core import app, entities as core_entities
@@ -17,7 +15,7 @@ from ...utils import image
from libs.dify_service_api.v1 import client, errors
@runner.runner_class("dify-service-api")
@runner.runner_class('dify-service-api')
class DifyServiceAPIRunner(runner.RequestRunner):
"""Dify Service API 对话请求器"""
@@ -27,38 +25,54 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["chat", "agent", "workflow"]
valid_app_types = ['chat', 'agent', 'workflow']
if (
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
self.pipeline_config['ai']['dify-service-api']['app-type']
not in valid_app_types
):
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
api_key = self.pipeline_config['ai']['dify-service-api']['api-key']
self.dify_client = client.AsyncDifyServiceClient(
api_key=api_key,
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
base_url=self.pipeline_config['ai']['dify-service-api']['base-url'],
)
def _try_convert_thinking(self, resp_text: str) -> str:
"""尝试转换 Dify 的思考提示"""
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
if not resp_text.startswith(
'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>'
):
return resp_text
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'original'
):
return resp_text
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'remove'
):
return re.sub(
r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>',
'',
resp_text,
flags=re.DOTALL,
)
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'plain'
):
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
return f"<think>{thinking_text.group(1)}</think>\n{content_text}"
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
async def _preprocess_user_message(
self, query: core_entities.Query
@@ -68,22 +82,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
Returns:
tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID
"""
plain_text = ""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == "text":
if ce.type == 'text':
plain_text += ce.text
elif ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
elif ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(
ce.image_base64
)
file_bytes = base64.b64decode(image_b64)
file = ("img.png", file_bytes, f"image/{image_format}")
file = ('img.png', file_bytes, f'image/{image_format}')
file_upload_resp = await self.dify_client.upload_file(
file,
f"{query.session.launcher_type.value}_{query.session.launcher_id}",
f'{query.session.launcher_type.value}_{query.session.launcher_id}',
)
image_id = file_upload_resp["id"]
image_id = file_upload_resp['id']
image_ids.append(image_id)
elif isinstance(query.user_message.content, str):
plain_text = query.user_message.content
@@ -94,116 +110,119 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""
cov_id = query.session.using_conversation.uuid or ''
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
mode = "basic" # 标记是基础编排还是工作流编排
mode = 'basic' # 标记是基础编排还是工作流编排
basic_mode_pending_chunk = ''
inputs = {}
inputs.update(query.variables)
async for chunk in self.dify_client.chat_messages(
inputs=inputs,
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
conversation_id=cov_id,
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
if chunk['event'] == 'workflow_started':
mode = "workflow"
mode = 'workflow'
if mode == "workflow":
if mode == 'workflow':
if chunk['event'] == 'node_finished':
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
role="assistant",
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
role='assistant',
content=self._try_convert_thinking(
chunk['data']['outputs']['answer']
),
)
elif mode == "basic":
elif mode == 'basic':
if chunk['event'] == 'message':
basic_mode_pending_chunk += chunk['answer']
elif chunk['event'] == 'message_end':
yield llm_entities.Message(
role="assistant",
role='assistant',
content=self._try_convert_thinking(basic_mode_pending_chunk),
)
basic_mode_pending_chunk = ''
query.session.using_conversation.uuid = chunk["conversation_id"]
query.session.using_conversation.uuid = chunk['conversation_id']
async def _agent_chat_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""
cov_id = query.session.using_conversation.uuid or ''
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
ignored_events = ["agent_message"]
ignored_events = ['agent_message']
inputs = {}
inputs.update(query.variables)
async for chunk in self.dify_client.chat_messages(
inputs=inputs,
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
response_mode="streaming",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
response_mode='streaming',
conversation_id=cov_id,
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
if chunk["event"] in ignored_events:
if chunk['event'] in ignored_events:
continue
if chunk["event"] == "agent_thought":
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
if chunk['event'] == 'agent_thought':
if (
chunk['tool'] != '' and chunk['observation'] != ''
): # 工具调用结果,跳过
continue
if chunk['thought'].strip() != '': # 文字回复内容
msg = llm_entities.Message(
role="assistant",
content=chunk["thought"],
role='assistant',
content=chunk['thought'],
)
yield msg
if chunk['tool']:
msg = llm_entities.Message(
role="assistant",
role='assistant',
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type="function",
type='function',
function=llm_entities.FunctionCall(
name=chunk["tool"],
name=chunk['tool'],
arguments=json.dumps({}),
),
)
@@ -211,9 +230,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
)
yield msg
if chunk['event'] == 'message_file':
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
base_url = self.dify_client.base_url
if base_url.endswith('/v1'):
@@ -222,11 +239,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
image_url = base_url + chunk['url']
yield llm_entities.Message(
role="assistant",
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
query.session.using_conversation.uuid = chunk["conversation_id"]
query.session.using_conversation.uuid = chunk['conversation_id']
async def _workflow_messages(
self, query: core_entities.Query
@@ -235,58 +252,57 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
query.variables["conversation_id"] = query.session.using_conversation.uuid
query.variables['conversation_id'] = query.session.using_conversation.uuid
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
ignored_events = ["text_chunk", "workflow_started"]
ignored_events = ['text_chunk', 'workflow_started']
inputs = { # these variables are legacy variables, we need to keep them for compatibility
"langbot_user_message_text": plain_text,
"langbot_session_id": query.variables["session_id"],
"langbot_conversation_id": query.variables["conversation_id"],
"langbot_msg_create_time": query.variables["msg_create_time"],
'langbot_user_message_text': plain_text,
'langbot_session_id': query.variables['session_id'],
'langbot_conversation_id': query.variables['conversation_id'],
'langbot_msg_create_time': query.variables['msg_create_time'],
}
inputs.update(query.variables)
async for chunk in self.dify_client.workflow_run(
inputs=inputs,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
if chunk["event"] in ignored_events:
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
if chunk['event'] in ignored_events:
continue
if chunk["event"] == "node_started":
if chunk['event'] == 'node_started':
if (
chunk["data"]["node_type"] == "start"
or chunk["data"]["node_type"] == "end"
chunk['data']['node_type'] == 'start'
or chunk['data']['node_type'] == 'end'
):
continue
msg = llm_entities.Message(
role="assistant",
role='assistant',
content=None,
tool_calls=[
llm_entities.ToolCall(
id=chunk["data"]["node_id"],
type="function",
id=chunk['data']['node_id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk["data"]["title"],
name=chunk['data']['title'],
arguments=json.dumps({}),
),
)
@@ -295,13 +311,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
elif chunk["event"] == "workflow_finished":
elif chunk['event'] == 'workflow_finished':
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
msg = llm_entities.Message(
role="assistant",
content=chunk["data"]["outputs"]["summary"],
role='assistant',
content=chunk['data']['outputs']['summary'],
)
yield msg
@@ -310,16 +326,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query):
yield msg
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
async for msg in self._agent_chat_messages(query):
yield msg
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
async for msg in self._workflow_messages(query):
yield msg
else:
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)
+21 -13
View File
@@ -4,24 +4,28 @@ import json
import typing
from .. import runner
from ...core import app, entities as core_entities
from ...core import entities as core_entities
from .. import entities as llm_entities
@runner.runner_class("local-agent")
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器
"""
"""本地Agent请求运行器"""
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
req_messages = (
query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
)
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(
query, query.use_llm_model, req_messages, query.use_funcs
)
yield msg
@@ -34,7 +38,7 @@ class LocalAgentRunner(runner.RequestRunner):
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
@@ -42,7 +46,9 @@ class LocalAgentRunner(runner.RequestRunner):
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
role='tool',
content=json.dumps(func_ret, ensure_ascii=False),
tool_call_id=tool_call.id,
)
yield msg
@@ -51,7 +57,7 @@ class LocalAgentRunner(runner.RequestRunner):
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
role='tool', content=f'err: {e}', tool_call_id=tool_call.id
)
yield err_msg
@@ -59,7 +65,9 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(
query, query.use_llm_model, req_messages, query.use_funcs
)
yield msg