style: restrict line-length

This commit is contained in:
Junyan Qin
2025-05-10 18:04:58 +08:00
parent b30016ed08
commit 055b389353
134 changed files with 1096 additions and 2595 deletions

View File

@@ -80,17 +80,13 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return (
str(self.role) + ': ' + str(self.get_content_platform_message_chain())
)
return str(self.role) + ': ' + str(self.get_content_platform_message_chain())
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'
def get_content_platform_message_chain(
self, prefix_text: str = ''
) -> platform_message.MessageChain | None:
def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None:
"""将内容转换为平台消息 MessageChain 对象
Args:
@@ -100,9 +96,7 @@ class Message(pydantic.BaseModel):
if self.content is None:
return None
elif isinstance(self.content, str):
return platform_message.MessageChain(
[platform_message.Plain(prefix_text + self.content)]
)
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
elif isinstance(self.content, list):
mc = []
for ce in self.content:

View File

@@ -43,16 +43,12 @@ class ModelManager:
self.requester_dict = {}
async def initialize(self):
self.requester_components = self.ap.discover.get_components_by_kind(
'LLMAPIRequester'
)
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
# forge requester class dict
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
for component in self.requester_components:
requester_dict[component.metadata.name] = (
component.get_python_component_class()
)
requester_dict[component.metadata.name] = component.get_python_component_class()
self.requester_dict = requester_dict
@@ -65,9 +61,7 @@ class ModelManager:
self.llm_models = []
# llm models
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
llm_models = result.all()
@@ -77,9 +71,7 @@ class ModelManager:
async def load_llm_model(
self,
model_info: persistence_model.LLMModel
| sqlalchemy.Row[persistence_model.LLMModel]
| dict,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
):
"""加载模型"""
@@ -88,9 +80,7 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
requester_inst = self.requester_dict[model_info.requester](
ap=self.ap, config=model_info.requester_config
)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
await requester_inst.initialize()
@@ -136,9 +126,7 @@ class ModelManager:
return component.to_plain_dict()
return None
def get_available_requester_manifest_by_name(
self, name: str
) -> engine.Component | None:
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
"""通过名称获取请求器清单"""
for component in self.requester_components:
if component.metadata.name == name:

View File

@@ -73,9 +73,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
if system_role_message:
messages.pop(i)
if isinstance(system_role_message, llm_entities.Message) and isinstance(
system_role_message.content, str
):
if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str):
args['system'] = system_role_message.content
req_messages = []
@@ -106,9 +104,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
elif isinstance(m.content, list):
for i, ce in enumerate(m.content):
if ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(
ce.image_base64
)
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
alter_image_ele = {
'type': 'image',
@@ -156,9 +152,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
for block in resp.content:
if block.type == 'thinking':
args['content'] = (
'<think>' + block.thinking + '</think>\n' + args['content']
)
args['content'] = '<think>' + block.thinking + '</think>\n' + args['content']
elif block.type == 'text':
args['content'] += block.text
elif block.type == 'tool_use':
@@ -166,9 +160,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
tool_call = llm_entities.ToolCall(
id=block.id,
type='function',
function=llm_entities.FunctionCall(
name=block.name, arguments=json.dumps(block.input)
),
function=llm_entities.FunctionCall(name=block.name, arguments=json.dumps(block.input)),
)
if 'tool_calls' not in args:
args['tool_calls'] = []

View File

@@ -28,9 +28,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
api_key='',
base_url=self.requester_cfg['base_url'].replace(' ', ''),
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(
trust_env=True, timeout=self.requester_cfg['timeout']
),
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
)
async def _req(
@@ -50,20 +48,11 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
chatcmpl_message['role'] = 'assistant'
reasoning_content = (
chatcmpl_message['reasoning_content']
if 'reasoning_content' in chatcmpl_message
else None
)
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
# deepseek的reasoner模型
if reasoning_content is not None:
chatcmpl_message['content'] = (
'<think>\n'
+ reasoning_content
+ '\n</think>\n'
+ chatcmpl_message['content']
)
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
message = llm_entities.Message(**chatcmpl_message)
@@ -124,10 +113,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
content = msg_dict.get('content')
if isinstance(content, list):
# 检查 content 列表中是否每个部分都是文本
if all(
isinstance(part, dict) and part.get('type') == 'text'
for part in content
):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
# 将所有文本部分合并为一个字符串
msg_dict['content'] = '\n'.join(part['text'] for part in content)
req_messages.append(msg_dict)

View File

@@ -1,23 +1,17 @@
from __future__ import annotations
import asyncio
import typing
import json
import base64
from typing import AsyncGenerator
import openai
import openai.types.chat.chat_completion as chat_completion
import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call
import httpx
import aiohttp
import async_lru
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image
class ModelScopeChatCompletions(requester.LLMAPIRequester):
@@ -33,26 +27,22 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
self.requester_cfg = self.ap.provider_cfg.data['requester']['modelscope-chat-completions']
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
api_key='',
base_url=self.requester_cfg['base-url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(
trust_env=True,
timeout=self.requester_cfg['timeout']
)
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
)
async def _req(
self,
args: dict,
) -> chat_completion.ChatCompletion:
args["stream"] = True
args['stream'] = True
chunk = None
pending_content = ""
pending_content = ''
tool_calls = []
@@ -74,7 +64,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
break
else:
tool_calls.append(tool_call)
if chunk.choices[0].finish_reason is not None:
break
@@ -82,36 +72,41 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
for tc in tool_calls:
function = chat_completion_message_tool_call.Function(
name=tc.function.name,
arguments=tc.function.arguments
name=tc.function.name, arguments=tc.function.arguments
)
real_tool_calls.append(chat_completion_message_tool_call.ChatCompletionMessageToolCall(
id=tc.id,
function=function,
type="function"
))
return chat_completion.ChatCompletion(
id=chunk.id,
object="chat.completion",
created=chunk.created,
choices=[
chat_completion.Choice(
index=0,
message=chat_completion.ChatCompletionMessage(
role="assistant",
content=pending_content,
tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None
),
finish_reason=chunk.choices[0].finish_reason if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None else 'stop',
logprobs=chunk.choices[0].logprobs,
real_tool_calls.append(
chat_completion_message_tool_call.ChatCompletionMessageToolCall(
id=tc.id, function=function, type='function'
)
],
model=chunk.model,
service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None,
system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None,
usage=chunk.usage if hasattr(chunk, 'usage') else None
) if chunk else None
)
return (
chat_completion.ChatCompletion(
id=chunk.id,
object='chat.completion',
created=chunk.created,
choices=[
chat_completion.Choice(
index=0,
message=chat_completion.ChatCompletionMessage(
role='assistant',
content=pending_content,
tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None,
),
finish_reason=chunk.choices[0].finish_reason
if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None
else 'stop',
logprobs=chunk.choices[0].logprobs,
)
],
model=chunk.model,
service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None,
system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None,
usage=chunk.usage if hasattr(chunk, 'usage') else None,
)
if chunk
else None
)
return await self.client.chat.completions.create(**args)
async def _make_msg(
@@ -138,29 +133,27 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args['model'] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
# 设置此次请求中的messages
messages = req_messages.copy()
# 检查vision
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_base64":
me["image_url"] = {
"url": me["image_base64"]
}
me["type"] = "image_url"
del me["image_base64"]
if 'content' in msg and isinstance(msg['content'], list):
for me in msg['content']:
if me['type'] == 'image_base64':
me['image_url'] = {'url': me['image_base64']}
me['type'] = 'image_url'
del me['image_base64']
args["messages"] = messages
args['messages'] = messages
# 发送请求
resp = await self._req(args)
@@ -180,12 +173,12 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
msg_dict = m.dict(exclude_none=True)
content = msg_dict.get("content")
content = msg_dict.get('content')
if isinstance(content, list):
# 检查 content 列表中是否每个部分都是文本
if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
# 将所有文本部分合并为一个字符串
msg_dict["content"] = "\n".join(part["text"] for part in content)
msg_dict['content'] = '\n'.join(part['text'] for part in content)
req_messages.append(msg_dict)
try:
@@ -204,4 +197,4 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
raise errors.RequesterError(f'请求错误: {e.message}')

View File

@@ -61,13 +61,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
msg['content'] = '\n'.join(text_content)
msg['images'] = [url.split(',')[1] for url in image_urls]
if (
'tool_calls' in msg
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg['tool_calls']:
tool_call['function']['arguments'] = json.loads(
tool_call['function']['arguments']
)
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
args['messages'] = messages
args['tools'] = []
@@ -80,9 +76,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
message: llm_entities.Message = await self._make_msg(resp)
return message
async def _make_msg(
self, chat_completions: ollama.ChatResponse
) -> llm_entities.Message:
async def _make_msg(self, chat_completions: ollama.ChatResponse) -> llm_entities.Message:
message: ollama.Message = chat_completions.message
if message is None:
raise ValueError("chat_completions must contain a 'message' field")
@@ -122,10 +116,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
msg_dict: dict = m.dict(exclude_none=True)
content: Any = msg_dict.get('content')
if isinstance(content, list):
if all(
isinstance(part, dict) and part.get('type') == 'text'
for part in content
):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
msg_dict['content'] = '\n'.join(part['text'] for part in content)
req_messages.append(msg_dict)
try:

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import openai
from . import chatcmpl, modelscopechatcmpl
from .. import requester
from . import chatcmpl
from ....core import app
class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
"""欧派云 ChatCompletion API 请求器"""
@@ -17,4 +16,4 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions']
self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions']

View File

@@ -35,8 +35,6 @@ class RequestRunner(abc.ABC):
self.pipeline_config = pipeline_config
@abc.abstractmethod
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pass

View File

@@ -26,7 +26,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
app_type: str # 应用类型
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
references_quote: (
str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
)
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
@@ -42,9 +44,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
# 初始化Dashscope 参数配置
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][
'references_quote'
]
self.references_quote = self.pipeline_config['ai']['dashscope-app-api']['references_quote']
def _replace_references(self, text, references_dict):
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
@@ -65,9 +65,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
# 使用 re.sub() 进行替换
return pattern.sub(replacement, text)
async def _preprocess_user_message(
self, query: core_entities.Query
) -> tuple[str, list[str]]:
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
plain_text = ''
image_ids = []
@@ -91,9 +89,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
return plain_text, image_ids
async def _agent_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 智能体对话请求"""
# 局部变量
@@ -151,9 +147,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content,
)
async def _workflow_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 工作流对话请求"""
# 局部变量
@@ -216,9 +210,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content,
)
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行"""
if self.app_type == 'agent':
async for msg in self._agent_messages(query):

View File

@@ -26,10 +26,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self.pipeline_config = pipeline_config
valid_app_types = ['chat', 'agent', 'workflow']
if (
self.pipeline_config['ai']['dify-service-api']['app-type']
not in valid_app_types
):
if self.pipeline_config['ai']['dify-service-api']['app-type'] not in valid_app_types:
raise errors.DifyAPIError(
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)
@@ -48,16 +45,10 @@ class DifyServiceAPIRunner(runner.RequestRunner):
):
return resp_text
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'original'
):
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'original':
return resp_text
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'remove'
):
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'remove':
return re.sub(
r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>',
'',
@@ -65,18 +56,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
flags=re.DOTALL,
)
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'plain'
):
if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'plain':
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
async def _preprocess_user_message(
self, query: core_entities.Query
) -> tuple[str, list[str]]:
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
Returns:
@@ -90,9 +76,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if ce.type == 'text':
plain_text += ce.text
elif ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(
ce.image_base64
)
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
file_bytes = base64.b64decode(image_b64)
file = ('img.png', file_bytes, f'image/{image_format}')
file_upload_resp = await self.dify_client.upload_file(
@@ -106,9 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
return plain_text, image_ids
async def _chat_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ''
@@ -151,9 +133,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
role='assistant',
content=self._try_convert_thinking(
chunk['data']['outputs']['answer']
),
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
)
elif mode == 'basic':
if chunk['event'] == 'message':
@@ -166,9 +146,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
basic_mode_pending_chunk = ''
if chunk is None:
raise errors.DifyAPIError(
'Dify API 没有返回任何响应请检查网络连接和API配置'
)
raise errors.DifyAPIError('Dify API 没有返回任何响应请检查网络连接和API配置')
query.session.using_conversation.uuid = chunk['conversation_id']
@@ -217,9 +195,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
pending_agent_message += chunk['answer']
else:
if pending_agent_message.strip() != '':
pending_agent_message = pending_agent_message.replace(
'</details>Action:', '</details>'
)
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
yield llm_entities.Message(
role='assistant',
content=self._try_convert_thinking(pending_agent_message),
@@ -227,9 +203,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
pending_agent_message = ''
if chunk['event'] == 'agent_thought':
if (
chunk['tool'] != '' and chunk['observation'] != ''
): # 工具调用结果,跳过
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
continue
if chunk['tool']:
@@ -258,23 +232,17 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield llm_entities.Message(
role='assistant',
content=[
llm_entities.ContentElement.from_image_url(image_url)
],
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
if chunk['event'] == 'error':
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
if chunk is None:
raise errors.DifyAPIError(
'Dify API 没有返回任何响应请检查网络连接和API配置'
)
raise errors.DifyAPIError('Dify API 没有返回任何响应请检查网络连接和API配置')
query.session.using_conversation.uuid = chunk['conversation_id']
async def _workflow_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用工作流"""
if not query.session.using_conversation.uuid:
@@ -315,10 +283,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
continue
if chunk['event'] == 'node_started':
if (
chunk['data']['node_type'] == 'start'
or chunk['data']['node_type'] == 'end'
):
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
continue
msg = llm_entities.Message(
@@ -349,9 +314,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query):

View File

@@ -12,15 +12,11 @@ from .. import entities as llm_entities
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pending_tool_calls = []
req_messages = (
query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
)
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(
@@ -45,9 +41,7 @@ class LocalAgentRunner(runner.RequestRunner):
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
msg = llm_entities.Message(
role='tool',
@@ -60,9 +54,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role='tool', content=f'err: {e}', tool_call_id=tool_call.id
)
err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
yield err_msg

View File

@@ -23,10 +23,7 @@ class SessionManager:
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
"""获取会话"""
for session in self.session_list:
if (
query.launcher_type == session.launcher_type
and query.launcher_id == session.launcher_id
):
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
return session
session_concurrency = self.ap.instance_config.data['concurrency']['session']

View File

@@ -45,9 +45,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行工具调用"""
pass

View File

@@ -43,15 +43,11 @@ class RuntimeMCPSession:
env=self.server_config['env'],
)
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(stdio, write)
)
self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
await self.session.initialize()
@@ -66,25 +62,19 @@ class RuntimeMCPSession:
sseio, write = sse_transport
self.session = await self.exit_stack.enter_async_context(
ClientSession(sseio, write)
)
self.session = await self.exit_stack.enter_async_context(ClientSession(sseio, write))
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
)
self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}')
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
)
raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}')
tools = await self.session.list_tools()
@@ -132,9 +122,7 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
'servers', []
):
for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
@@ -155,9 +143,7 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:

View File

@@ -48,9 +48,7 @@ class PluginToolLoader(loader.ToolLoader):
return function, plugin.plugin_inst
return None, None
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
try:
function, plugin = await self._get_function_and_plugin(name)
if function is None:

View File

@@ -28,9 +28,7 @@ class ToolManager:
await loader_inst.initialize()
self.loaders.append(loader_inst)
async def get_all_functions(
self, plugin_enabled: bool = None
) -> list[entities.LLMFunction]:
async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
@@ -39,9 +37,7 @@ class ToolManager:
return all_functions
async def generate_tools_for_openai(
self, use_funcs: list[entities.LLMFunction]
) -> list:
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
"""生成函数列表"""
tools = []
@@ -58,9 +54,7 @@ class ToolManager:
return tools
async def generate_tools_for_anthropic(
self, use_funcs: list[entities.LLMFunction]
) -> list:
async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list:
"""为anthropic生成函数列表
e.g.
@@ -95,9 +89,7 @@ class ToolManager:
return tools
async def execute_func_call(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
"""执行函数调用"""
for loader in self.loaders: