mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 16:26:02 +00:00
style: restrict line-length
This commit is contained in:
@@ -80,17 +80,13 @@ class Message(pydantic.BaseModel):
|
||||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return (
|
||||
str(self.role) + ': ' + str(self.get_content_platform_message_chain())
|
||||
)
|
||||
return str(self.role) + ': ' + str(self.get_content_platform_message_chain())
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
def get_content_platform_message_chain(
|
||||
self, prefix_text: str = ''
|
||||
) -> platform_message.MessageChain | None:
|
||||
def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None:
|
||||
"""将内容转换为平台消息 MessageChain 对象
|
||||
|
||||
Args:
|
||||
@@ -100,9 +96,7 @@ class Message(pydantic.BaseModel):
|
||||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(prefix_text + self.content)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
|
||||
@@ -43,16 +43,12 @@ class ModelManager:
|
||||
self.requester_dict = {}
|
||||
|
||||
async def initialize(self):
|
||||
self.requester_components = self.ap.discover.get_components_by_kind(
|
||||
'LLMAPIRequester'
|
||||
)
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
|
||||
# forge requester class dict
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
requester_dict[component.metadata.name] = (
|
||||
component.get_python_component_class()
|
||||
)
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
|
||||
self.requester_dict = requester_dict
|
||||
|
||||
@@ -65,9 +61,7 @@ class ModelManager:
|
||||
self.llm_models = []
|
||||
|
||||
# llm models
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
llm_models = result.all()
|
||||
|
||||
@@ -77,9 +71,7 @@ class ModelManager:
|
||||
|
||||
async def load_llm_model(
|
||||
self,
|
||||
model_info: persistence_model.LLMModel
|
||||
| sqlalchemy.Row[persistence_model.LLMModel]
|
||||
| dict,
|
||||
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
|
||||
):
|
||||
"""加载模型"""
|
||||
|
||||
@@ -88,9 +80,7 @@ class ModelManager:
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
|
||||
requester_inst = self.requester_dict[model_info.requester](
|
||||
ap=self.ap, config=model_info.requester_config
|
||||
)
|
||||
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
|
||||
|
||||
await requester_inst.initialize()
|
||||
|
||||
@@ -136,9 +126,7 @@ class ModelManager:
|
||||
return component.to_plain_dict()
|
||||
return None
|
||||
|
||||
def get_available_requester_manifest_by_name(
|
||||
self, name: str
|
||||
) -> engine.Component | None:
|
||||
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
|
||||
"""通过名称获取请求器清单"""
|
||||
for component in self.requester_components:
|
||||
if component.metadata.name == name:
|
||||
|
||||
@@ -73,9 +73,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
if system_role_message:
|
||||
messages.pop(i)
|
||||
|
||||
if isinstance(system_role_message, llm_entities.Message) and isinstance(
|
||||
system_role_message.content, str
|
||||
):
|
||||
if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str):
|
||||
args['system'] = system_role_message.content
|
||||
|
||||
req_messages = []
|
||||
@@ -106,9 +104,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
elif isinstance(m.content, list):
|
||||
for i, ce in enumerate(m.content):
|
||||
if ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(
|
||||
ce.image_base64
|
||||
)
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
|
||||
alter_image_ele = {
|
||||
'type': 'image',
|
||||
@@ -156,9 +152,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
|
||||
for block in resp.content:
|
||||
if block.type == 'thinking':
|
||||
args['content'] = (
|
||||
'<think>' + block.thinking + '</think>\n' + args['content']
|
||||
)
|
||||
args['content'] = '<think>' + block.thinking + '</think>\n' + args['content']
|
||||
elif block.type == 'text':
|
||||
args['content'] += block.text
|
||||
elif block.type == 'tool_use':
|
||||
@@ -166,9 +160,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
tool_call = llm_entities.ToolCall(
|
||||
id=block.id,
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=block.name, arguments=json.dumps(block.input)
|
||||
),
|
||||
function=llm_entities.FunctionCall(name=block.name, arguments=json.dumps(block.input)),
|
||||
)
|
||||
if 'tool_calls' not in args:
|
||||
args['tool_calls'] = []
|
||||
|
||||
@@ -28,9 +28,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base_url'].replace(' ', ''),
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(
|
||||
trust_env=True, timeout=self.requester_cfg['timeout']
|
||||
),
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
)
|
||||
|
||||
async def _req(
|
||||
@@ -50,20 +48,11 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
||||
chatcmpl_message['role'] = 'assistant'
|
||||
|
||||
reasoning_content = (
|
||||
chatcmpl_message['reasoning_content']
|
||||
if 'reasoning_content' in chatcmpl_message
|
||||
else None
|
||||
)
|
||||
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
|
||||
|
||||
# deepseek的reasoner模型
|
||||
if reasoning_content is not None:
|
||||
chatcmpl_message['content'] = (
|
||||
'<think>\n'
|
||||
+ reasoning_content
|
||||
+ '\n</think>\n'
|
||||
+ chatcmpl_message['content']
|
||||
)
|
||||
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
@@ -124,10 +113,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
content = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(
|
||||
isinstance(part, dict) and part.get('type') == 'text'
|
||||
for part in content
|
||||
):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call
|
||||
import httpx
|
||||
import aiohttp
|
||||
import async_lru
|
||||
|
||||
from .. import entities, errors, requester
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....utils import image
|
||||
|
||||
|
||||
class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
@@ -33,26 +27,22 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['modelscope-chat-completions']
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
api_key='',
|
||||
base_url=self.requester_cfg['base-url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(
|
||||
trust_env=True,
|
||||
timeout=self.requester_cfg['timeout']
|
||||
)
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
)
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
) -> chat_completion.ChatCompletion:
|
||||
args["stream"] = True
|
||||
args['stream'] = True
|
||||
|
||||
chunk = None
|
||||
|
||||
pending_content = ""
|
||||
pending_content = ''
|
||||
|
||||
tool_calls = []
|
||||
|
||||
@@ -74,7 +64,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
break
|
||||
else:
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
break
|
||||
|
||||
@@ -82,36 +72,41 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
for tc in tool_calls:
|
||||
function = chat_completion_message_tool_call.Function(
|
||||
name=tc.function.name,
|
||||
arguments=tc.function.arguments
|
||||
name=tc.function.name, arguments=tc.function.arguments
|
||||
)
|
||||
real_tool_calls.append(chat_completion_message_tool_call.ChatCompletionMessageToolCall(
|
||||
id=tc.id,
|
||||
function=function,
|
||||
type="function"
|
||||
))
|
||||
|
||||
return chat_completion.ChatCompletion(
|
||||
id=chunk.id,
|
||||
object="chat.completion",
|
||||
created=chunk.created,
|
||||
choices=[
|
||||
chat_completion.Choice(
|
||||
index=0,
|
||||
message=chat_completion.ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=pending_content,
|
||||
tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None
|
||||
),
|
||||
finish_reason=chunk.choices[0].finish_reason if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None else 'stop',
|
||||
logprobs=chunk.choices[0].logprobs,
|
||||
real_tool_calls.append(
|
||||
chat_completion_message_tool_call.ChatCompletionMessageToolCall(
|
||||
id=tc.id, function=function, type='function'
|
||||
)
|
||||
],
|
||||
model=chunk.model,
|
||||
service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None,
|
||||
system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None,
|
||||
usage=chunk.usage if hasattr(chunk, 'usage') else None
|
||||
) if chunk else None
|
||||
)
|
||||
|
||||
return (
|
||||
chat_completion.ChatCompletion(
|
||||
id=chunk.id,
|
||||
object='chat.completion',
|
||||
created=chunk.created,
|
||||
choices=[
|
||||
chat_completion.Choice(
|
||||
index=0,
|
||||
message=chat_completion.ChatCompletionMessage(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None,
|
||||
),
|
||||
finish_reason=chunk.choices[0].finish_reason
|
||||
if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None
|
||||
else 'stop',
|
||||
logprobs=chunk.choices[0].logprobs,
|
||||
)
|
||||
],
|
||||
model=chunk.model,
|
||||
service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None,
|
||||
system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None,
|
||||
usage=chunk.usage if hasattr(chunk, 'usage') else None,
|
||||
)
|
||||
if chunk
|
||||
else None
|
||||
)
|
||||
return await self.client.chat.completions.create(**args)
|
||||
|
||||
async def _make_msg(
|
||||
@@ -138,29 +133,27 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args['model'] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "image_base64":
|
||||
me["image_url"] = {
|
||||
"url": me["image_base64"]
|
||||
}
|
||||
me["type"] = "image_url"
|
||||
del me["image_base64"]
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
|
||||
args["messages"] = messages
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
@@ -180,12 +173,12 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
content = msg_dict.get("content")
|
||||
content = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
@@ -204,4 +197,4 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||
except openai.RateLimitError as e:
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
@@ -61,13 +61,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
msg['content'] = '\n'.join(text_content)
|
||||
msg['images'] = [url.split(',')[1] for url in image_urls]
|
||||
if (
|
||||
'tool_calls' in msg
|
||||
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg['tool_calls']:
|
||||
tool_call['function']['arguments'] = json.loads(
|
||||
tool_call['function']['arguments']
|
||||
)
|
||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
||||
args['messages'] = messages
|
||||
|
||||
args['tools'] = []
|
||||
@@ -80,9 +76,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
message: llm_entities.Message = await self._make_msg(resp)
|
||||
return message
|
||||
|
||||
async def _make_msg(
|
||||
self, chat_completions: ollama.ChatResponse
|
||||
) -> llm_entities.Message:
|
||||
async def _make_msg(self, chat_completions: ollama.ChatResponse) -> llm_entities.Message:
|
||||
message: ollama.Message = chat_completions.message
|
||||
if message is None:
|
||||
raise ValueError("chat_completions must contain a 'message' field")
|
||||
@@ -122,10 +116,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
if all(
|
||||
isinstance(part, dict) and part.get('type') == 'text'
|
||||
for part in content
|
||||
):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import openai
|
||||
|
||||
from . import chatcmpl, modelscopechatcmpl
|
||||
from .. import requester
|
||||
from . import chatcmpl
|
||||
from ....core import app
|
||||
|
||||
|
||||
class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""欧派云 ChatCompletion API 请求器"""
|
||||
|
||||
@@ -17,4 +16,4 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions']
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions']
|
||||
|
||||
@@ -35,8 +35,6 @@ class RequestRunner(abc.ABC):
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
pass
|
||||
|
||||
@@ -26,7 +26,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
app_type: str # 应用类型
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
references_quote: (
|
||||
str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
)
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
@@ -42,9 +44,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
# 初始化Dashscope 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
|
||||
self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][
|
||||
'references_quote'
|
||||
]
|
||||
self.references_quote = self.pipeline_config['ai']['dashscope-app-api']['references_quote']
|
||||
|
||||
def _replace_references(self, text, references_dict):
|
||||
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
|
||||
@@ -65,9 +65,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
# 使用 re.sub() 进行替换
|
||||
return pattern.sub(replacement, text)
|
||||
|
||||
async def _preprocess_user_message(
|
||||
self, query: core_entities.Query
|
||||
) -> tuple[str, list[str]]:
|
||||
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
@@ -91,9 +89,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""Dashscope 智能体对话请求"""
|
||||
|
||||
# 局部变量
|
||||
@@ -151,9 +147,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""Dashscope 工作流对话请求"""
|
||||
|
||||
# 局部变量
|
||||
@@ -216,9 +210,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行"""
|
||||
if self.app_type == 'agent':
|
||||
async for msg in self._agent_messages(query):
|
||||
|
||||
@@ -26,10 +26,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ['chat', 'agent', 'workflow']
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['app-type']
|
||||
not in valid_app_types
|
||||
):
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] not in valid_app_types:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
@@ -48,16 +45,10 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
):
|
||||
return resp_text
|
||||
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'original'
|
||||
):
|
||||
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'original':
|
||||
return resp_text
|
||||
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'remove'
|
||||
):
|
||||
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'remove':
|
||||
return re.sub(
|
||||
r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>',
|
||||
'',
|
||||
@@ -65,18 +56,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
if (
|
||||
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
|
||||
== 'plain'
|
||||
):
|
||||
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'plain':
|
||||
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
|
||||
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
|
||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
|
||||
|
||||
async def _preprocess_user_message(
|
||||
self, query: core_entities.Query
|
||||
) -> tuple[str, list[str]]:
|
||||
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
|
||||
|
||||
Returns:
|
||||
@@ -90,9 +76,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(
|
||||
ce.image_base64
|
||||
)
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
file = ('img.png', file_bytes, f'image/{image_format}')
|
||||
file_upload_resp = await self.dify_client.upload_file(
|
||||
@@ -106,9 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
|
||||
@@ -151,9 +133,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if chunk['data']['node_type'] == 'answer':
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(
|
||||
chunk['data']['outputs']['answer']
|
||||
),
|
||||
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
|
||||
)
|
||||
elif mode == 'basic':
|
||||
if chunk['event'] == 'message':
|
||||
@@ -166,9 +146,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError(
|
||||
'Dify API 没有返回任何响应,请检查网络连接和API配置'
|
||||
)
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
@@ -217,9 +195,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
pending_agent_message += chunk['answer']
|
||||
else:
|
||||
if pending_agent_message.strip() != '':
|
||||
pending_agent_message = pending_agent_message.replace(
|
||||
'</details>Action:', '</details>'
|
||||
)
|
||||
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(pending_agent_message),
|
||||
@@ -227,9 +203,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
pending_agent_message = ''
|
||||
|
||||
if chunk['event'] == 'agent_thought':
|
||||
if (
|
||||
chunk['tool'] != '' and chunk['observation'] != ''
|
||||
): # 工具调用结果,跳过
|
||||
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
|
||||
continue
|
||||
|
||||
if chunk['tool']:
|
||||
@@ -258,23 +232,17 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=[
|
||||
llm_entities.ContentElement.from_image_url(image_url)
|
||||
],
|
||||
content=[llm_entities.ContentElement.from_image_url(image_url)],
|
||||
)
|
||||
if chunk['event'] == 'error':
|
||||
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError(
|
||||
'Dify API 没有返回任何响应,请检查网络连接和API配置'
|
||||
)
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
@@ -315,10 +283,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if (
|
||||
chunk['data']['node_type'] == 'start'
|
||||
or chunk['data']['node_type'] == 'end'
|
||||
):
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
|
||||
msg = llm_entities.Message(
|
||||
@@ -349,9 +314,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
yield msg
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
|
||||
@@ -12,15 +12,11 @@ from .. import entities as llm_entities
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""本地Agent请求运行器"""
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = (
|
||||
query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
)
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
@@ -45,9 +41,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role='tool',
|
||||
@@ -60,9 +54,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role='tool', content=f'err: {e}', tool_call_id=tool_call.id
|
||||
)
|
||||
err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
|
||||
|
||||
yield err_msg
|
||||
|
||||
|
||||
@@ -23,10 +23,7 @@ class SessionManager:
|
||||
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
|
||||
"""获取会话"""
|
||||
for session in self.session_list:
|
||||
if (
|
||||
query.launcher_type == session.launcher_type
|
||||
and query.launcher_id == session.launcher_id
|
||||
):
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
return session
|
||||
|
||||
session_concurrency = self.ap.instance_config.data['concurrency']['session']
|
||||
|
||||
@@ -45,9 +45,7 @@ class ToolLoader(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -43,15 +43,11 @@ class RuntimeMCPSession:
|
||||
env=self.server_config['env'],
|
||||
)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
stdio_client(server_params)
|
||||
)
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
|
||||
stdio, write = stdio_transport
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(stdio, write)
|
||||
)
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
@@ -66,25 +62,19 @@ class RuntimeMCPSession:
|
||||
|
||||
sseio, write = sse_transport
|
||||
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(sseio, write)
|
||||
)
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(sseio, write))
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
async def initialize(self):
|
||||
self.ap.logger.debug(
|
||||
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
|
||||
)
|
||||
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
|
||||
|
||||
if self.server_config['mode'] == 'stdio':
|
||||
await self._init_stdio_python_server()
|
||||
elif self.server_config['mode'] == 'sse':
|
||||
await self._init_sse_server()
|
||||
else:
|
||||
raise ValueError(
|
||||
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
|
||||
)
|
||||
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
|
||||
|
||||
tools = await self.session.list_tools()
|
||||
|
||||
@@ -132,9 +122,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
self._last_listed_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
|
||||
'servers', []
|
||||
):
|
||||
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
|
||||
if not server_config['enable']:
|
||||
continue
|
||||
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
|
||||
@@ -155,9 +143,7 @@ class MCPLoader(loader.ToolLoader):
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
return name in [f.name for f in self._last_listed_functions]
|
||||
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
for server_name, session in self.sessions.items():
|
||||
for function in session.functions:
|
||||
if function.name == name:
|
||||
|
||||
@@ -48,9 +48,7 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
return function, plugin.plugin_inst
|
||||
return None, None
|
||||
|
||||
async def invoke_tool(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
try:
|
||||
function, plugin = await self._get_function_and_plugin(name)
|
||||
if function is None:
|
||||
|
||||
@@ -28,9 +28,7 @@ class ToolManager:
|
||||
await loader_inst.initialize()
|
||||
self.loaders.append(loader_inst)
|
||||
|
||||
async def get_all_functions(
|
||||
self, plugin_enabled: bool = None
|
||||
) -> list[entities.LLMFunction]:
|
||||
async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数"""
|
||||
all_functions: list[entities.LLMFunction] = []
|
||||
|
||||
@@ -39,9 +37,7 @@ class ToolManager:
|
||||
|
||||
return all_functions
|
||||
|
||||
async def generate_tools_for_openai(
|
||||
self, use_funcs: list[entities.LLMFunction]
|
||||
) -> list:
|
||||
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
|
||||
"""生成函数列表"""
|
||||
tools = []
|
||||
|
||||
@@ -58,9 +54,7 @@ class ToolManager:
|
||||
|
||||
return tools
|
||||
|
||||
async def generate_tools_for_anthropic(
|
||||
self, use_funcs: list[entities.LLMFunction]
|
||||
) -> list:
|
||||
async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list:
|
||||
"""为anthropic生成函数列表
|
||||
|
||||
e.g.
|
||||
@@ -95,9 +89,7 @@ class ToolManager:
|
||||
|
||||
return tools
|
||||
|
||||
async def execute_func_call(
|
||||
self, query: core_entities.Query, name: str, parameters: dict
|
||||
) -> typing.Any:
|
||||
async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
||||
"""执行函数调用"""
|
||||
|
||||
for loader in self.loaders:
|
||||
|
||||
Reference in New Issue
Block a user