refactor: switch llm_entities to plugin sdk

This commit is contained in:
Junyan Qin
2025-07-13 20:30:17 +08:00
parent 4a319b2b20
commit 6a1de889b4
15 changed files with 76 additions and 378 deletions

View File

@@ -1,171 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: typing.Union[pipeline_query.Query, None]
"""此次请求的query对象非请求过程的事件时为None"""
class Config:
arbitrary_types_allowed = True
class PersonMessageReceived(BaseEventModel):
"""收到任何私聊消息时"""
launcher_type: str
"""发起对象类型(group/person)"""
launcher_id: typing.Union[int, str]
"""发起对象ID(群号/QQ号)"""
sender_id: typing.Union[int, str]
"""发送者ID(QQ号)"""
message_chain: platform_message.MessageChain
class GroupMessageReceived(BaseEventModel):
"""收到任何群聊消息时"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
message_chain: platform_message.MessageChain
class PersonNormalMessageReceived(BaseEventModel):
"""判断为应该处理的私聊普通消息时触发"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
text_message: str
alter: typing.Optional[str] = None
"""修改后的消息文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class PersonCommandSent(BaseEventModel):
"""判断为应该处理的私聊命令时触发"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
command: str
params: list[str]
text_message: str
is_admin: bool
alter: typing.Optional[str] = None
"""修改后的完整命令文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class GroupNormalMessageReceived(BaseEventModel):
"""判断为应该处理的群聊普通消息时触发"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
text_message: str
alter: typing.Optional[str] = None
"""修改后的消息文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class GroupCommandSent(BaseEventModel):
"""判断为应该处理的群聊命令时触发"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
command: str
params: list[str]
text_message: str
is_admin: bool
alter: typing.Optional[str] = None
"""修改后的完整命令文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class NormalMessageResponded(BaseEventModel):
"""回复普通消息时触发"""
launcher_type: str
launcher_id: typing.Union[int, str]
sender_id: typing.Union[int, str]
session: provider_session.Session
"""会话对象"""
prefix: str
"""回复消息的前缀"""
response_text: str
"""回复消息的文本"""
finish_reason: str
"""响应结束原因"""
funcs_called: list[str]
"""调用的函数列表"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class PromptPreProcessing(BaseEventModel):
"""会话中的Prompt预处理时触发"""
session_name: str
default_prompt: list[llm_entities.Message]
"""此对话的情景预设,可修改"""
prompt: list[llm_entities.Message]
"""此对话现有消息记录,可修改"""

View File

@@ -1,135 +0,0 @@
from __future__ import annotations
import typing
import pydantic
from pkg.provider import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
class FunctionCall(pydantic.BaseModel):
name: str
arguments: str
class ToolCall(pydantic.BaseModel):
id: str
type: str
function: FunctionCall
class ImageURLContentObject(pydantic.BaseModel):
url: str
def __str__(self):
return self.url[:128] + ('...' if len(self.url) > 128 else '')
class ContentElement(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[ImageURLContentObject] = None
image_base64: typing.Optional[str] = None
def __str__(self):
if self.type == 'text':
return self.text
elif self.type == 'image_url':
return f'[图片]({self.image_url})'
else:
return '未知内容'
@classmethod
def from_text(cls, text: str):
return cls(type='text', text=text)
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
@classmethod
def from_image_base64(cls, image_base64: str):
return cls(type='image_base64', image_base64=image_base64)
class Message(pydantic.BaseModel):
"""消息"""
role: str # user, system, assistant, tool, command, plugin
"""消息的角色"""
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
"""内容"""
tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
tool_call_id: typing.Optional[str] = None
def readable_str(self) -> str:
if self.content is not None:
return str(self.role) + ': ' + str(self.get_content_platform_message_chain())
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'
def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None:
"""将内容转换为平台消息 MessageChain 对象
Args:
prefix_text (str): 首个文字组件的前缀文本
"""
if self.content is None:
return None
elif isinstance(self.content, str):
return platform_message.MessageChain([platform_message.Plain(text=(prefix_text + self.content))])
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url':
if ce.image_url.url.startswith('http'):
mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
if b64_str.startswith('data:'):
b64_str = b64_str.split(',')[1]
mc.append(platform_message.Image(base64=b64_str))
# 找第一个文字组件
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, platform_message.Plain):
mc[i] = platform_message.Plain(prefix_text + c.text)
break
else:
mc.insert(0, platform_message.Plain(prefix_text))
return platform_message.MessageChain(mc)
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -4,11 +4,11 @@ import abc
import typing
from ...core import app
from .. import entities as llm_entities
from ...entity.persistence import model as persistence_model
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from . import token
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class RuntimeLLMModel:
@@ -58,10 +58,10 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
self,
query: pipeline_query.Query,
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
"""调用API
Args:

View File

@@ -9,10 +9,10 @@ import httpx
from .. import errors, requester
from ... import entities as llm_entities
from ....utils import image
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class AnthropicMessages(requester.LLMAPIRequester):
@@ -50,10 +50,10 @@ class AnthropicMessages(requester.LLMAPIRequester):
self,
query: pipeline_query.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = model.token_mgr.get_token()
args = extra_args.copy()
@@ -73,7 +73,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
if system_role_message:
messages.pop(i)
if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str):
if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str):
args['system'] = system_role_message.content
req_messages = []
@@ -157,16 +157,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
args['content'] += block.text
elif block.type == 'tool_use':
assert type(block) is anthropic.types.tool_use_block.ToolUseBlock
tool_call = llm_entities.ToolCall(
tool_call = provider_message.ToolCall(
id=block.id,
type='function',
function=llm_entities.FunctionCall(name=block.name, arguments=json.dumps(block.input)),
function=provider_message.FunctionCall(name=block.name, arguments=json.dumps(block.input)),
)
if 'tool_calls' not in args:
args['tool_calls'] = []
args['tool_calls'].append(tool_call)
return llm_entities.Message(**args)
return provider_message.Message(**args)
except anthropic.AuthenticationError as e:
raise errors.RequesterError(f'api-key 无效: {e.message}')
except anthropic.BadRequestError as e:

View File

@@ -8,9 +8,9 @@ import openai.types.chat.chat_completion as chat_completion
import httpx
from .. import errors, requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class OpenAIChatCompletions(requester.LLMAPIRequester):
@@ -41,7 +41,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
async def _make_msg(
self,
chat_completion: chat_completion.ChatCompletion,
) -> llm_entities.Message:
) -> provider_message.Message:
chatcmpl_message = chat_completion.choices[0].message.model_dump()
# 确保 role 字段存在且不为 None
@@ -54,7 +54,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
if reasoning_content is not None:
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
message = llm_entities.Message(**chatcmpl_message)
message = provider_message.Message(**chatcmpl_message)
return message
@@ -65,7 +65,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = {}
@@ -103,10 +103,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self,
query: pipeline_query.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
msg_dict = m.dict(exclude_none=True)

View File

@@ -4,9 +4,9 @@ import typing
from . import chatcmpl
from .. import errors, requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -24,7 +24,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = {}

View File

@@ -5,9 +5,9 @@ import typing
from . import chatcmpl
from .. import requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -25,7 +25,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = {}

View File

@@ -9,9 +9,9 @@ import openai.types.chat.chat_completion_message_tool_call as chat_completion_me
import httpx
from .. import entities, errors, requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class ModelScopeChatCompletions(requester.LLMAPIRequester):
@@ -112,14 +112,14 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
async def _make_msg(
self,
chat_completion: chat_completion.ChatCompletion,
) -> llm_entities.Message:
) -> provider_message.Message:
chatcmpl_message = chat_completion.choices[0].message.dict()
# 确保 role 字段存在且不为 None
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
chatcmpl_message['role'] = 'assistant'
message = llm_entities.Message(**chatcmpl_message)
message = provider_message.Message(**chatcmpl_message)
return message
@@ -130,7 +130,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = {}
@@ -168,10 +168,10 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
self,
query: pipeline_query.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
msg_dict = m.dict(exclude_none=True)

View File

@@ -5,9 +5,9 @@ import typing
from . import chatcmpl
from .. import requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -25,7 +25,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = {}

View File

@@ -10,9 +10,9 @@ import json
import ollama
from .. import errors, requester
from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
REQUESTER_NAME: str = 'ollama-chat'
@@ -44,7 +44,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
args = extra_args.copy()
args['model'] = use_model.model_entity.name
@@ -73,27 +73,27 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
args['tools'] = tools
resp = await self._req(args)
message: llm_entities.Message = await self._make_msg(resp)
message: provider_message.Message = await self._make_msg(resp)
return message
async def _make_msg(self, chat_completions: ollama.ChatResponse) -> llm_entities.Message:
async def _make_msg(self, chat_completions: ollama.ChatResponse) -> provider_message.Message:
message: ollama.Message = chat_completions.message
if message is None:
raise ValueError("chat_completions must contain a 'message' field")
ret_msg: llm_entities.Message = None
ret_msg: provider_message.Message = None
if message.content is not None:
ret_msg = llm_entities.Message(role='assistant', content=message.content)
ret_msg = provider_message.Message(role='assistant', content=message.content)
if message.tool_calls is not None and len(message.tool_calls) > 0:
tool_calls: list[llm_entities.ToolCall] = []
tool_calls: list[provider_message.ToolCall] = []
for tool_call in message.tool_calls:
tool_calls.append(
llm_entities.ToolCall(
provider_message.ToolCall(
id=uuid.uuid4().hex,
type='function',
function=llm_entities.FunctionCall(
function=provider_message.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments),
),
@@ -107,10 +107,10 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
self,
query: pipeline_query.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
messages: typing.List[provider_message.Message],
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> provider_message.Message:
req_messages: list = []
for m in messages:
msg_dict: dict = m.dict(exclude_none=True)

View File

@@ -4,7 +4,7 @@ import abc
import typing
from ..core import app
from . import entities as llm_entities
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@@ -36,6 +36,6 @@ class RequestRunner(abc.ABC):
self.pipeline_config = pipeline_config
@abc.abstractmethod
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
pass

View File

@@ -7,8 +7,8 @@ import dashscope
from .. import runner
from ...core import app
from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class DashscopeAPIError(Exception):
@@ -90,7 +90,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
return plain_text, image_ids
async def _agent_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _agent_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""Dashscope 智能体对话请求"""
# 局部变量
@@ -143,14 +145,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=pending_content,
)
async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""Dashscope 工作流对话请求"""
# 局部变量
@@ -208,12 +210,12 @@ class DashScopeAPIRunner(runner.RequestRunner):
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=pending_content,
)
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行"""
if self.app_type == 'agent':
async for msg in self._agent_messages(query):

View File

@@ -9,7 +9,7 @@ import base64
from .. import runner
from ...core import app
from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ...utils import image
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from libs.dify_service_api.v1 import client, errors
@@ -90,7 +90,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
return plain_text, image_ids
async def _chat_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ''
query.variables['conversation_id'] = cov_id
@@ -132,7 +134,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if mode == 'workflow':
if chunk['event'] == 'node_finished':
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
)
@@ -140,7 +142,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['event'] == 'message':
basic_mode_pending_chunk += chunk['answer']
elif chunk['event'] == 'message_end':
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=self._try_convert_thinking(basic_mode_pending_chunk),
)
@@ -153,7 +155,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
async def _agent_chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ''
query.variables['conversation_id'] = cov_id
@@ -198,7 +200,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
else:
if pending_agent_message.strip() != '':
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=self._try_convert_thinking(pending_agent_message),
)
@@ -209,13 +211,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
continue
if chunk['tool']:
msg = llm_entities.Message(
msg = provider_message.Message(
role='assistant',
tool_calls=[
llm_entities.ToolCall(
provider_message.ToolCall(
id=chunk['id'],
type='function',
function=llm_entities.FunctionCall(
function=provider_message.FunctionCall(
name=chunk['tool'],
arguments=json.dumps({}),
),
@@ -232,9 +234,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
image_url = base_url + chunk['url']
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
content=[provider_message.ContentElement.from_image_url(image_url)],
)
if chunk['event'] == 'error':
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
@@ -246,7 +248,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用工作流"""
if not query.session.using_conversation.uuid:
@@ -290,14 +292,14 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
continue
msg = llm_entities.Message(
msg = provider_message.Message(
role='assistant',
content=None,
tool_calls=[
llm_entities.ToolCall(
provider_message.ToolCall(
id=chunk['data']['node_id'],
type='function',
function=llm_entities.FunctionCall(
function=provider_message.FunctionCall(
name=chunk['data']['title'],
arguments=json.dumps({}),
),
@@ -311,14 +313,14 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
msg = llm_entities.Message(
msg = provider_message.Message(
role='assistant',
content=chunk['data']['outputs']['summary'],
)
yield msg
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query):

View File

@@ -4,15 +4,15 @@ import json
import typing
from .. import runner
from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
pending_tool_calls = []
@@ -45,7 +45,7 @@ class LocalAgentRunner(runner.RequestRunner):
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters)
msg = llm_entities.Message(
msg = provider_message.Message(
role='tool',
content=json.dumps(func_ret, ensure_ascii=False),
tool_call_id=tool_call.id,
@@ -56,7 +56,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
yield err_msg

View File

@@ -7,8 +7,8 @@ import aiohttp
from .. import runner
from ...core import app
from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
class N8nAPIError(Exception):
@@ -68,7 +68,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
return plain_text
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用n8n webhook"""
# 生成会话ID如果不存在
if not query.session.using_conversation.uuid:
@@ -146,7 +146,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
output_content = json.dumps(response_data, ensure_ascii=False)
# 返回消息
yield llm_entities.Message(
yield provider_message.Message(
role='assistant',
content=output_content,
)
@@ -154,7 +154,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
async for msg in self._call_webhook(query):
yield msg