feat: 添加对 chat 和 workflow 的支持

This commit is contained in:
Junyan Qin
2024-12-14 17:51:11 +08:00
parent 2ea3ff0b5c
commit dbf9f2398e
18 changed files with 1301 additions and 187 deletions

View File

@@ -91,7 +91,7 @@ class Query(pydantic.BaseModel):
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: sysprompt_entities.Prompt
@@ -105,6 +105,9 @@ class Conversation(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("dify-service-api-config", 16)
class DifyServiceAPICfgMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'dify-service-api' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api'] = {
"base-url": "https://api.dify.ai/v1",
"app-type": "chat",
"chat": {
"api-key": "sk-1234567890"
},
"workflow": {
"api-key": "sk-1234567890",
"output-key": "summary"
}
}
await self.ap.provider_cfg.dump_config()

View File

@@ -7,7 +7,7 @@ from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config
from ..migrations import m015_gitee_ai_config, m016_dify_service_api
@stage.stage_class("MigrationStage")

View File

@@ -4,7 +4,7 @@ from . import runner
from ..core import app
from .runners import localagent
from .runners import difyapi
from .runners import difysvapi
class RunnerManager:

View File

@@ -1,205 +1,205 @@
from __future__ import annotations
# from __future__ import annotations
import json
import typing
import aiohttp
# import json
# import typing
# import aiohttp
from .. import runner
from ...core import app, entities as core_entities
from .. import entities as llm_entities
# from .. import runner
# from ...core import app, entities as core_entities
# from .. import entities as llm_entities
api_url = "请求地址/v1"
api_key = "请求key"
user_name = "dify-plugin"
# 需要在dify的自定义字段中另外设置context和system_prompt
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# api_url = "请求地址/v1"
# api_key = "请求key"
# user_name = "dify-plugin"
# # 需要在dify的自定义字段中另外设置context和system_prompt
# headers = {
# "Authorization": f"Bearer {api_key}",
# "Content-Type": "application/json"
# }
def get_content_text(content):
if isinstance(content, list):
return " ".join(str(element) if element.image_url is None else " " for element in content)
elif isinstance(content, str):
return content
else:
return ""
# def get_content_text(content):
# if isinstance(content, list):
# return " ".join(str(element) if element.image_url is None else " " for element in content)
# elif isinstance(content, str):
# return content
# else:
# return ""
@runner.runner_class("difyapi")
class DifyAgentRunner(runner.RequestRunner):
"""Dify API 对话请求器
"""
# @runner.runner_class("dify-api")
# class DifyAPIRunner(runner.RequestRunner):
# """Dify API 对话请求器
# """
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
await query.use_model.requester.preprocess(query)
# async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
# """运行请求"""
# await query.use_model.requester.preprocess(query)
# 构建系统提示词
prompt_messages = query.prompt.messages.copy()
system_prompt = "\n".join(
f"{msg.role}: {get_content_text(msg.content)}" for msg in prompt_messages if msg.content
)
# # 构建系统提示词
# prompt_messages = query.prompt.messages.copy()
# system_prompt = "\n".join(
# f"{msg.role}: {get_content_text(msg.content)}" for msg in prompt_messages if msg.content
# )
# 构建上下文
previous_messages = query.messages.copy()
user_message = [query.user_message]
# # 构建上下文
# previous_messages = query.messages.copy()
# user_message = [query.user_message]
# 检查 user_message 中的 image_url
image_urls = [element.image_url.url for element in query.user_message.content if element.type == 'image_url' and element.image_url is not None]
# # 检查 user_message 中的 image_url
# image_urls = [element.image_url.url for element in query.user_message.content if element.type == 'image_url' and element.image_url is not None]
if len(image_urls) > 10:
raise ValueError("仅可包含最多10张图片")
# if len(image_urls) > 10:
# raise ValueError("仅可包含最多10张图片")
data = {}
if image_urls:
data["files"] = [
{
"type": "image",
"transfer_method": "remote_url",
"url": url
} for url in image_urls
]
else:
data["files"] = []
# data = {}
# if image_urls:
# data["files"] = [
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": url
# } for url in image_urls
# ]
# else:
# data["files"] = []
# 继续处理其他逻辑
all_messages = previous_messages + user_message
# # 继续处理其他逻辑
# all_messages = previous_messages + user_message
context = "\n".join(
f"{msg.role}: {get_content_text(msg.content)}" for msg in all_messages if msg.content
)
# context = "\n".join(
# f"{msg.role}: {get_content_text(msg.content)}" for msg in all_messages if msg.content
# )
# 构建请求数据
data.update({
"inputs": {
"context": context,
"system_prompt": system_prompt,
"files": data["files"]
},
"query": get_content_text(query.user_message.content),
"response_mode": "blocking",
"conversation_id": "",
"user": user_name
# # 构建请求数据
# data.update({
# "inputs": {
# "context": context,
# "system_prompt": system_prompt,
# "files": data["files"]
# },
# "query": get_content_text(query.user_message.content),
# "response_mode": "blocking",
# "conversation_id": "",
# "user": user_name
})
# })
async with aiohttp.ClientSession() as session:
try:
async with session.post(api_url + "/chat-messages", headers=headers, json=data) as response:
response_data = await response.json()
response.raise_for_status()
# async with aiohttp.ClientSession() as session:
# try:
# async with session.post(api_url + "/chat-messages", headers=headers, json=data) as response:
# response_data = await response.json()
# response.raise_for_status()
# 处理响应数据
content_elements = [llm_entities.ContentElement.from_text(response_data.get("answer", ""))]
# # 处理响应数据
# content_elements = [llm_entities.ContentElement.from_text(response_data.get("answer", ""))]
msg = llm_entities.Message(
role="assistant",
content=content_elements
)
yield msg
except aiohttp.ClientResponseError as http_err:
if response.status == 404:
error_message = "对话不存在"
elif response.status == 400:
error_code = response_data.get("code")
if error_code == "invalid_param":
error_message = "传入参数异常"
elif error_code == "app_unavailable":
error_message = "App 配置不可用"
elif error_code == "provider_not_initialize":
error_message = "无可用模型凭据配置"
elif error_code == "provider_quota_exceeded":
error_message = "模型调用额度不足"
elif error_code == "model_currently_not_support":
error_message = "当前模型不可用"
elif error_code == "completion_request_error":
error_message = "文本生成失败"
elif response.status == 500:
error_message = "服务内部异常"
else:
error_message = f"HTTP error occurred: {http_err}"
raise ValueError(error_message)
except Exception as err:
raise ValueError(f"An error occurred: {err}")
# msg = llm_entities.Message(
# role="assistant",
# content=content_elements
# )
# yield msg
# except aiohttp.ClientResponseError as http_err:
# if response.status == 404:
# error_message = "对话不存在"
# elif response.status == 400:
# error_code = response_data.get("code")
# if error_code == "invalid_param":
# error_message = "传入参数异常"
# elif error_code == "app_unavailable":
# error_message = "App 配置不可用"
# elif error_code == "provider_not_initialize":
# error_message = "无可用模型凭据配置"
# elif error_code == "provider_quota_exceeded":
# error_message = "模型调用额度不足"
# elif error_code == "model_currently_not_support":
# error_message = "当前模型不可用"
# elif error_code == "completion_request_error":
# error_message = "文本生成失败"
# elif response.status == 500:
# error_message = "服务内部异常"
# else:
# error_message = f"HTTP error occurred: {http_err}"
# raise ValueError(error_message)
# except Exception as err:
# raise ValueError(f"An error occurred: {err}")
@runner.runner_class("local-agent")
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器
"""
# # @runner.runner_class("local-agent")
# # class LocalAgentRunner(runner.RequestRunner):
# # """本地Agent请求运行器
# # """
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
await query.use_model.requester.preprocess(query)
# # async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
# # """运行请求
# # """
# # await query.use_model.requester.preprocess(query)
pending_tool_calls = []
# # pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# # req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
try:
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
if "answer" not in msg.content:
raise ValueError("请求失败返回内容不含answer")
except Exception as e:
err_msg = llm_entities.Message(
role="system", content=f"请求失败:{e}"
)
yield err_msg
return
# # # 首次请求
# # try:
# # msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
# # if "answer" not in msg.content:
# # raise ValueError("请求失败返回内容不含answer")
# # except Exception as e:
# # err_msg = llm_entities.Message(
# # role="system", content=f"请求失败:{e}"
# # )
# # yield err_msg
# # return
yield msg
# # yield msg
pending_tool_calls = msg.tool_calls
# # pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# # req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
# # # 持续请求,只要还有待处理的工具调用就继续处理调用
# # while pending_tool_calls:
# # for tool_call in pending_tool_calls:
# # try:
# # func = tool_call.function
parameters = json.loads(func.arguments)
# # parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
# # func_ret = await self.ap.tool_mgr.execute_func_call(
# # query, func.name, parameters
# # )
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
# # msg = llm_entities.Message(
# # role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
# # )
yield msg
# # yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
# # req_messages.append(msg)
# # except Exception as e:
# # # 工具调用出错,添加一个报错信息到 req_messages
# # err_msg = llm_entities.Message(
# # role="tool", content=f"err: {e}", tool_call_id=tool_call.id
# # )
yield err_msg
# # yield err_msg
req_messages.append(err_msg)
# # req_messages.append(err_msg)
# 处理完所有调用,再次请求
try:
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
if "answer" not in msg.content:
raise ValueError("请求失败返回内容不含answer")
except Exception as e:
err_msg = llm_entities.Message(
role="system", content=f"请求失败:{e}"
)
yield err_msg
return
# # # 处理完所有调用,再次请求
# # try:
# # msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
# # if "answer" not in msg.content:
# # raise ValueError("请求失败返回内容不含answer")
# # except Exception as e:
# # err_msg = llm_entities.Message(
# # role="system", content=f"请求失败:{e}"
# # )
# # yield err_msg
# # return
yield msg
# # yield msg
pending_tool_calls = msg.tool_calls
# # pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# # req_messages.append(msg)

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
import typing
import json
import uuid
from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities
from ...utils import image
from libs.dify_service_api.v1 import client, errors
@runner.runner_class("dify-service-api")
class DifyServiceAPIRunner(runner.RequestRunner):
"""Dify Service API 对话请求器"""
dify_client: client.AsyncDifyServiceClient
async def initialize(self):
"""初始化"""
valid_app_types = ['chat', 'workflow']
if self.ap.provider_cfg.data['dify-service-api']['app-type'] not in valid_app_types:
raise errors.DifyAPIError(f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}")
api_key = self.ap.provider_cfg.data['dify-service-api'][self.ap.provider_cfg.data['dify-service-api']['app-type']]['api-key']
self.dify_client = client.AsyncDifyServiceClient(
api_key=api_key,
base_url=self.ap.provider_cfg.data['dify-service-api']['base-url']
)
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
Returns:
tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID
"""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
elif ce.type == 'image_url':
file_bytes, image_format = await image.get_qq_image_bytes(ce.image_url.url)
file = ("img.png", file_bytes, f"image/{image_format}")
file_upload_resp = await self.dify_client.upload_file(file, f"{query.session.launcher_type.value}_{query.session.launcher_id}")
image_id = file_upload_resp['id']
image_ids.append(image_id)
elif isinstance(query.user_message.content, str):
plain_text = query.user_message.content
return plain_text, image_ids
async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""
plain_text, image_ids = await self._preprocess_user_message(query)
files = [{
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
} for image_id in image_ids]
resp = await self.dify_client.chat_messages(inputs={}, query=plain_text, user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", conversation_id=cov_id, files=files)
msg = llm_entities.Message(
role='assistant',
content=resp['answer'],
)
yield msg
query.session.using_conversation.uuid = resp['conversation_id']
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用工作流"""
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
cov_id = query.session.using_conversation.uuid
plain_text, image_ids = await self._preprocess_user_message(query)
files = [{
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
} for image_id in image_ids]
ignored_events = ['text_chunk', 'workflow_started']
async for chunk in self.dify_client.workflow_run(inputs={
"langbot_user_message_text": plain_text,
"langbot_session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
"langbot_conversation_id": cov_id,
}, user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", files=files):
if chunk['event'] in ignored_events:
continue
if chunk['event'] == 'node_started':
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
continue
msg = llm_entities.Message(
role='assistant',
content=None,
tool_calls=[llm_entities.ToolCall(
id=chunk['data']['node_id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk['data']['title'],
arguments=json.dumps({}),
),
)],
)
yield msg
elif chunk['event'] == 'workflow_finished':
msg = llm_entities.Message(
role='assistant',
content=chunk['data']['outputs'][self.ap.provider_cfg.data['dify-service-api']['workflow']['output-key']],
)
yield msg
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.ap.provider_cfg.data['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query):
yield msg
elif self.ap.provider_cfg.data['dify-service-api']['app-type'] == 'workflow':
async for msg in self._workflow_messages(query):
yield msg
else:
raise errors.DifyAPIError(f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}")

View File

@@ -6,6 +6,31 @@ import ssl
import aiohttp
def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
"""获取QQ图片的下载链接"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
return f"http://{parsed.netloc}{parsed.path}", query
async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
"""获取QQ图片的bytes"""
image_url, query = get_qq_image_downloadable_url(image_url)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(image_url, params=query, ssl=ssl_context) as resp:
resp.raise_for_status()
file_bytes = await resp.read()
content_type = resp.headers.get('Content-Type')
if not content_type or not content_type.startswith('image/'):
image_format = 'jpeg'
else:
image_format = content_type.split('/')[-1]
return file_bytes, image_format
async def qq_image_url_to_base64(
image_url: str
) -> typing.Tuple[str, str]:
@@ -17,29 +42,12 @@ async def qq_image_url_to_base64(
Returns:
typing.Tuple[str, str]: base64编码和图片格式
"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
image_url, query = get_qq_image_downloadable_url(image_url)
# Flatten the query dictionary
query = {k: v[0] for k, v in query.items()}
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(
f"http://{parsed.netloc}{parsed.path}",
params=query,
ssl=ssl_context
) as resp:
resp.raise_for_status() # 检查HTTP错误
file_bytes = await resp.read()
content_type = resp.headers.get('Content-Type')
if not content_type or not content_type.startswith('image/'):
image_format = 'jpeg'
else:
image_format = content_type.split('/')[-1]
file_bytes, image_format = await get_qq_image_bytes(image_url)
base64_str = base64.b64encode(file_bytes).decode()