style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import typing
import enum
import pydantic.v1 as pydantic
from pkg.provider import entities
@@ -32,7 +31,6 @@ class ImageURLContentObject(pydantic.BaseModel):
class ContentElement(pydantic.BaseModel):
type: str
"""内容类型"""
@@ -57,7 +55,7 @@ class ContentElement(pydantic.BaseModel):
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
@classmethod
def from_image_base64(cls, image_base64: str):
return cls(type='image_base64', image_base64=image_base64)
@@ -82,15 +80,19 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return str(self.role) + ": " + str(self.get_content_platform_message_chain())
return (
str(self.role) + ': ' + str(self.get_content_platform_message_chain())
)
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'
def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
def get_content_platform_message_chain(
self, prefix_text: str = ''
) -> platform_message.MessageChain | None:
"""将内容转换为平台消息 MessageChain 对象
Args:
prefix_text (str): 首个文字组件的前缀文本
"""
@@ -98,21 +100,22 @@ class Message(pydantic.BaseModel):
if self.content is None:
return None
elif isinstance(self.content, str):
return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
return platform_message.MessageChain(
[platform_message.Plain(prefix_text + self.content)]
)
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url':
if ce.image_url.url.startswith("http"):
if ce.image_url.url.startswith('http'):
mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1]
if b64_str.startswith('data:'):
b64_str = b64_str.split(',')[1]
mc.append(platform_message.Image(base64=b64_str))
@@ -120,7 +123,7 @@ class Message(pydantic.BaseModel):
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, platform_message.Plain):
mc[i] = platform_message.Plain(prefix_text+c.text)
mc[i] = platform_message.Plain(prefix_text + c.text)
break
else:
mc.insert(0, platform_message.Plain(prefix_text))

View File

@@ -2,4 +2,4 @@ class RequesterError(Exception):
"""Base class for all Requester errors."""
def __init__(self, message: str):
super().__init__("模型请求失败: "+message)
super().__init__('模型请求失败: ' + message)

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import typing
import sqlalchemy
import pydantic.v1 as pydantic
from . import entities, requester
from ...core import app
@@ -12,10 +10,8 @@ from ..tools import entities as tools_entities
from ...discover import engine
from . import token
from ...entity.persistence import model as persistence_model
from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, volcarkchatcmpl, xaichatcmpl, zhipuaichatcmpl, lmstudiochatcmpl, siliconflowchatcmpl, volcarkchatcmpl
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list'
class ModelManager:
@@ -36,7 +32,7 @@ class ModelManager:
requester_components: list[engine.Component]
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
@@ -45,14 +41,18 @@ class ModelManager:
self.llm_models = []
self.requester_components = []
self.requester_dict = {}
async def initialize(self):
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
self.requester_components = self.ap.discover.get_components_by_kind(
'LLMAPIRequester'
)
# forge requester class dict
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
for component in self.requester_components:
requester_dict[component.metadata.name] = component.get_python_component_class()
requester_dict[component.metadata.name] = (
component.get_python_component_class()
)
self.requester_dict = requester_dict
@@ -74,18 +74,22 @@ class ModelManager:
# load models
for llm_model in llm_models:
await self.load_llm_model(llm_model)
async def load_llm_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
async def load_llm_model(
self,
model_info: persistence_model.LLMModel
| sqlalchemy.Row[persistence_model.LLMModel]
| dict,
):
"""加载模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
requester_inst = self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
ap=self.ap, config=model_info.requester_config
)
await requester_inst.initialize()
@@ -96,24 +100,23 @@ class ModelManager:
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=requester_inst
requester=requester_inst,
)
self.llm_models.append(runtime_llm_model)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
"""通过名称获取模型
"""
"""通过名称获取模型"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
raise ValueError(f'无法确定模型 {name} 的信息,请在元数据中配置')
async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
"""通过uuid获取模型"""
for model in self.llm_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f"model {uuid} not found")
raise ValueError(f'model {uuid} not found')
async def remove_llm_model(self, model_uuid: str):
"""移除模型"""
@@ -124,10 +127,7 @@ class ModelManager:
def get_available_requesters_info(self) -> list[dict]:
"""获取所有可用的请求器"""
return [
component.to_plain_dict()
for component in self.requester_components
]
return [component.to_plain_dict() for component in self.requester_components]
def get_available_requester_info_by_name(self, name: str) -> dict | None:
"""通过名称获取请求器信息"""
@@ -135,8 +135,10 @@ class ModelManager:
if component.metadata.name == name:
return component.to_plain_dict()
return None
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
def get_available_requester_manifest_by_name(
self, name: str
) -> engine.Component | None:
"""通过名称获取请求器清单"""
for component in self.requester_components:
if component.metadata.name == name:
@@ -151,4 +153,3 @@ class ModelManager:
funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
pass

View File

@@ -22,16 +22,21 @@ class RuntimeLLMModel:
requester: LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
def __init__(
self,
model_entity: persistence_model.LLMModel,
token_mgr: token.TokenManager,
requester: LLMAPIRequester,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
"""LLM API请求器"""
name: str = None
ap: app.Application
@@ -42,9 +47,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
def __init__(self, ap: app.Application, config: dict[str, typing.Any]):
self.ap = ap
self.requester_cfg = {
**self.default_config
}
self.requester_cfg = {**self.default_config}
self.requester_cfg.update(config)
async def initialize(self):

View File

@@ -2,16 +2,12 @@ from __future__ import annotations
import typing
import json
import traceback
import base64
import anthropic
import httpx
from ....core import app
from .. import entities, errors, requester
from .. import errors, requester
from .. import entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@@ -29,7 +25,6 @@ class AnthropicMessages(requester.LLMAPIRequester):
}
async def initialize(self):
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
base_url=self.requester_cfg['base_url'],
# cast to a valid type because mypy doesn't understand our type narrowing
@@ -40,7 +35,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
)
self.client = anthropic.AsyncAnthropic(
api_key="",
api_key='',
http_client=httpx_client,
)
@@ -55,7 +50,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
self.client.api_key = model.token_mgr.get_token()
args = extra_args.copy()
args["model"] = model.model_entity.name
args['model'] = model.model_entity.name
# 处理消息
@@ -63,14 +58,15 @@ class AnthropicMessages(requester.LLMAPIRequester):
system_role_message = None
for i, m in enumerate(messages):
if m.role == "system":
if m.role == 'system':
system_role_message = m
messages.pop(i)
break
if isinstance(system_role_message, llm_entities.Message) \
and isinstance(system_role_message.content, str):
if isinstance(system_role_message, llm_entities.Message) and isinstance(
system_role_message.content, str
):
args['system'] = system_role_message.content
req_messages = []
@@ -79,67 +75,64 @@ class AnthropicMessages(requester.LLMAPIRequester):
if m.role == 'tool':
tool_call_id = m.tool_call_id
req_messages.append({
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": m.content
}
]
})
req_messages.append(
{
'role': 'user',
'content': [
{
'type': 'tool_result',
'tool_use_id': tool_call_id,
'content': m.content,
}
],
}
)
continue
msg_dict = m.dict(exclude_none=True)
if isinstance(m.content, str) and m.content.strip() != "":
msg_dict["content"] = [
{
"type": "text",
"text": m.content
}
]
if isinstance(m.content, str) and m.content.strip() != '':
msg_dict['content'] = [{'type': 'text', 'text': m.content}]
elif isinstance(m.content, list):
for i, ce in enumerate(m.content):
if ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
if ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(
ce.image_base64
)
alter_image_ele = {
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{image_format}",
"data": image_b64
}
'type': 'image',
'source': {
'type': 'base64',
'media_type': f'image/{image_format}',
'data': image_b64,
},
}
msg_dict["content"][i] = alter_image_ele
msg_dict['content'][i] = alter_image_ele
if m.tool_calls:
for tool_call in m.tool_calls:
msg_dict["content"].append({
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments)
})
msg_dict['content'].append(
{
'type': 'tool_use',
'id': tool_call.id,
'name': tool_call.function.name,
'input': json.loads(tool_call.function.arguments),
}
)
del msg_dict["tool_calls"]
del msg_dict['tool_calls']
req_messages.append(msg_dict)
args["messages"] = req_messages
args['messages'] = req_messages
if funcs:
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
try:
# print(json.dumps(args, indent=4, ensure_ascii=False))
@@ -149,23 +142,24 @@ class AnthropicMessages(requester.LLMAPIRequester):
'content': '',
'role': resp.role,
}
assert type(resp) is anthropic.types.message.Message
for block in resp.content:
if block.type == 'thinking':
args['content'] = '<think>' + block.thinking + '</think>\n' + args['content']
args['content'] = (
'<think>' + block.thinking + '</think>\n' + args['content']
)
elif block.type == 'text':
args['content'] += block.text
elif block.type == 'tool_use':
assert type(block) is anthropic.types.tool_use_block.ToolUseBlock
tool_call = llm_entities.ToolCall(
id=block.id,
type="function",
type='function',
function=llm_entities.FunctionCall(
name=block.name,
arguments=json.dumps(block.input)
)
name=block.name, arguments=json.dumps(block.input)
),
)
if 'tool_calls' not in args:
args['tool_calls'] = []

View File

@@ -4,8 +4,6 @@ import typing
import openai
from . import chatcmpl
from .. import requester
from ....core import app
class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -2,22 +2,15 @@ from __future__ import annotations
import asyncio
import typing
import json
import base64
from typing import AsyncGenerator
import openai
import openai.types.chat.chat_completion as chat_completion
import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call
import httpx
import aiohttp
import async_lru
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image
class OpenAIChatCompletions(requester.LLMAPIRequester):
@@ -26,18 +19,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
"base_url": "https://api.openai.com/v1",
"timeout": 120,
'base_url': 'https://api.openai.com/v1',
'timeout': 120,
}
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
base_url=self.requester_cfg["base_url"],
timeout=self.requester_cfg["timeout"],
api_key='',
base_url=self.requester_cfg['base_url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(
trust_env=True, timeout=self.requester_cfg["timeout"]
trust_env=True, timeout=self.requester_cfg['timeout']
),
)
@@ -54,8 +46,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
chatcmpl_message = chat_completion.choices[0].message.model_dump()
# 确保 role 字段存在且不为 None
if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
chatcmpl_message["role"] = "assistant"
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
chatcmpl_message['role'] = 'assistant'
message = llm_entities.Message(**chatcmpl_message)
@@ -72,27 +64,27 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self.client.api_key = use_model.token_mgr.get_token()
args = extra_args.copy()
args["model"] = use_model.model_entity.name
args['model'] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
# 设置此次请求中的messages
messages = req_messages.copy()
# 检查vision
for msg in messages:
if "content" in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_base64":
me["image_url"] = {"url": me["image_base64"]}
me["type"] = "image_url"
del me["image_base64"]
if 'content' in msg and isinstance(msg['content'], list):
for me in msg['content']:
if me['type'] == 'image_base64':
me['image_url'] = {'url': me['image_base64']}
me['type'] = 'image_url'
del me['image_base64']
args["messages"] = messages
args['messages'] = messages
# 发送请求
resp = await self._req(args)
@@ -113,15 +105,15 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
msg_dict = m.dict(exclude_none=True)
content = msg_dict.get("content")
content = msg_dict.get('content')
if isinstance(content, list):
# 检查 content 列表中是否每个部分都是文本
if all(
isinstance(part, dict) and part.get("type") == "text"
isinstance(part, dict) and part.get('type') == 'text'
for part in content
):
# 将所有文本部分合并为一个字符串
msg_dict["content"] = "\n".join(part["text"] for part in content)
msg_dict['content'] = '\n'.join(part['text'] for part in content)
req_messages.append(msg_dict)
try:
@@ -133,17 +125,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError("请求超时")
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
if "context_length_exceeded" in e.message:
raise errors.RequesterError(f"上文过长,请重置会话: {e.message}")
if 'context_length_exceeded' in e.message:
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
else:
raise errors.RequesterError(f"请求参数错误: {e.message}")
raise errors.RequesterError(f'请求参数错误: {e.message}')
except openai.AuthenticationError as e:
raise errors.RequesterError(f"无效的 api-key: {e.message}")
raise errors.RequesterError(f'无效的 api-key: {e.message}')
except openai.NotFoundError as e:
raise errors.RequesterError(f"请求路径错误: {e.message}")
raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e:
raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}")
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f"请求错误: {e.message}")
raise errors.RequesterError(f'请求错误: {e.message}')

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing
from . import chatcmpl
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@@ -28,23 +28,23 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
self.client.api_key = use_model.token_mgr.get_token()
args = extra_args.copy()
args["model"] = use_model.model_entity.name
args['model'] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
if 'content' in m and isinstance(m['content'], list):
m['content'] = ' '.join([c['text'] for c in m['content']])
args["messages"] = messages
args['messages'] = messages
# 发送请求
resp = await self._req(args)
@@ -55,4 +55,4 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
# 处理请求结果
message = await self._make_msg(resp)
return message
return message

View File

@@ -1,17 +1,13 @@
from __future__ import annotations
import json
import asyncio
import aiohttp
import typing
from . import chatcmpl
from .. import entities, errors, requester
from ....core import app, entities as core_entities
from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from .. import entities as modelmgr_entities
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -33,20 +29,20 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
self.client.api_key = use_model.token_mgr.get_token()
args = extra_args.copy()
args["model"] = use_model.model_entity.name
args['model'] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
# gitee 不支持多模态把content都转换成纯文字
for m in req_messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
if 'content' in m and isinstance(m['content'], list):
m['content'] = ' '.join([c['text'] for c in m['content']])
args["messages"] = req_messages
args['messages'] = req_messages
resp = await self._req(args)

View File

@@ -4,8 +4,6 @@ import typing
import openai
from . import chatcmpl
from .. import requester
from ....core import app
class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -2,11 +2,10 @@ from __future__ import annotations
import typing
from ....core import app
from . import chatcmpl
from .. import entities, errors, requester
from ....core import entities as core_entities, app
from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@@ -30,26 +29,26 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
self.client.api_key = use_model.token_mgr.get_token()
args = extra_args.copy()
args["model"] = use_model.model_entity.name
args['model'] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
if 'content' in m and isinstance(m['content'], list):
m['content'] = ' '.join([c['text'] for c in m['content']])
# 删除空的
messages = [m for m in messages if m["content"].strip() != ""]
messages = [m for m in messages if m['content'].strip() != '']
args["messages"] = messages
args['messages'] = messages
# 发送请求
resp = await self._req(args)
@@ -57,4 +56,4 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
# 处理请求结果
message = await self._make_msg(resp)
return message
return message

View File

@@ -6,18 +6,15 @@ import typing
from typing import Union, Mapping, Any, AsyncIterator
import uuid
import json
import base64
import async_lru
import ollama
from .. import entities, errors, requester
from .. import errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app, entities as core_entities
from ....utils import image
from ....core import entities as core_entities
REQUESTER_NAME: str = "ollama-chat"
REQUESTER_NAME: str = 'ollama-chat'
class OllamaChatCompletions(requester.LLMAPIRequester):
@@ -26,13 +23,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
client: ollama.AsyncClient
default_config: dict[str, typing.Any] = {
"base_url": "http://127.0.0.1:11434",
"timeout": 120,
'base_url': 'http://127.0.0.1:11434',
'timeout': 120,
}
async def initialize(self):
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url']
self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout'])
async def _req(
self,
@@ -49,35 +46,35 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
args = extra_args.copy()
args["model"] = use_model.model_entity.name
args['model'] = use_model.model_entity.name
messages: list[dict] = req_messages.copy()
for msg in messages:
if "content" in msg and isinstance(msg["content"], list):
if 'content' in msg and isinstance(msg['content'], list):
text_content: list = []
image_urls: list = []
for me in msg["content"]:
if me["type"] == "text":
text_content.append(me["text"])
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])
for me in msg['content']:
if me['type'] == 'text':
text_content.append(me['text'])
elif me['type'] == 'image_base64':
image_urls.append(me['image_base64'])
msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(",")[1] for url in image_urls]
msg['content'] = '\n'.join(text_content)
msg['images'] = [url.split(',')[1] for url in image_urls]
if (
"tool_calls" in msg
'tool_calls' in msg
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg["tool_calls"]:
tool_call["function"]["arguments"] = json.loads(
tool_call["function"]["arguments"]
for tool_call in msg['tool_calls']:
tool_call['function']['arguments'] = json.loads(
tool_call['function']['arguments']
)
args["messages"] = messages
args['messages'] = messages
args["tools"] = []
args['tools'] = []
if user_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs)
if tools:
args["tools"] = tools
args['tools'] = tools
resp = await self._req(args)
message: llm_entities.Message = await self._make_msg(resp)
@@ -93,7 +90,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
ret_msg: llm_entities.Message = None
if message.content is not None:
ret_msg = llm_entities.Message(role="assistant", content=message.content)
ret_msg = llm_entities.Message(role='assistant', content=message.content)
if message.tool_calls is not None and len(message.tool_calls) > 0:
tool_calls: list[llm_entities.ToolCall] = []
@@ -101,7 +98,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
tool_calls.append(
llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
type='function',
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments),
@@ -123,13 +120,13 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
req_messages: list = []
for m in messages:
msg_dict: dict = m.dict(exclude_none=True)
content: Any = msg_dict.get("content")
content: Any = msg_dict.get('content')
if isinstance(content, list):
if all(
isinstance(part, dict) and part.get("type") == "text"
isinstance(part, dict) and part.get('type') == 'text'
for part in content
):
msg_dict["content"] = "\n".join(part["text"] for part in content)
msg_dict['content'] = '\n'.join(part['text'] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(
@@ -140,4 +137,4 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError("请求超时")
raise errors.RequesterError('请求超时')

View File

@@ -4,8 +4,6 @@ import typing
import openai
from . import chatcmpl
from .. import requester
from ....core import app
class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -4,8 +4,6 @@ import typing
import openai
from . import chatcmpl
from .. import requester
from ....core import app
class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -4,8 +4,6 @@ import typing
import openai
from . import chatcmpl
from .. import requester
from ....core import app
class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -3,9 +3,7 @@ from __future__ import annotations
import typing
import openai
from ....core import app
from . import chatcmpl
from .. import requester
class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import typing
class TokenManager():
"""鉴权 Token 管理器
"""
class TokenManager:
"""鉴权 Token 管理器"""
name: str
@@ -20,6 +19,6 @@ class TokenManager():
def get_token(self) -> str:
return self.tokens[self.using_token_index]
def next_token(self):
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)

View File

@@ -9,9 +9,10 @@ from . import entities as llm_entities
preregistered_runners: list[typing.Type[RequestRunner]] = []
def runner_class(name: str):
"""注册一个请求运行器
"""
"""注册一个请求运行器"""
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
cls.name = name
preregistered_runners.append(cls)
@@ -21,8 +22,8 @@ def runner_class(name: str):
class RequestRunner(abc.ABC):
"""请求运行器
"""
"""请求运行器"""
name: str = None
ap: app.Application
@@ -34,7 +35,8 @@ class RequestRunner(abc.ABC):
self.pipeline_config = pipeline_config
@abc.abstractmethod
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pass

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import typing
import json
import base64
import re
import dashscope
@@ -10,7 +8,7 @@ import dashscope
from .. import runner
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
class DashscopeAPIError(Exception):
"""Dashscope API 请求失败"""
@@ -20,49 +18,49 @@ class DashscopeAPIError(Exception):
super().__init__(self.message)
@runner.runner_class("dashscope-app-api")
@runner.runner_class('dashscope-app-api')
class DashScopeAPIRunner(runner.RequestRunner):
"阿里云百炼DashsscopeAPI对话请求器"
# 运行器内部使用的配置
app_type: str # 应用类型
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
app_type: str # 应用类型
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["agent", "workflow"]
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
#检查配置文件中使用的应用类型是否支持
if (self.app_type not in valid_app_types):
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.app_type}"
)
#初始化Dashscope 参数配置
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
valid_app_types = ['agent', 'workflow']
self.app_type = self.pipeline_config['ai']['dashscope-app-api']['app-type']
# 检查配置文件中使用的应用类型是否支持
if self.app_type not in valid_app_types:
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
# 初始化Dashscope 参数配置
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][
'references_quote'
]
def _replace_references(self, text, references_dict):
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
# 匹配 <ref>[index_id]</ref> 形式的字符串
pattern = re.compile(r'<ref>\[(.*?)\]</ref>')
def replacement(match):
# 获取引用编号
ref_key = match.group(1)
ref_key = match.group(1)
if ref_key in references_dict:
# 如果有对应的参考资料按照provider.json中的reference_quote返回提示来自哪个参考资料文件
return f"({self.references_quote} {references_dict[ref_key]})"
return f'({self.references_quote} {references_dict[ref_key]})'
else:
# 如果没有对应的参考资料,保留原样
return match.group(0)
return match.group(0)
# 使用 re.sub() 进行替换
return pattern.sub(replacement, text)
@@ -71,14 +69,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
plain_text = ""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == "text":
if ce.type == 'text':
plain_text += ce.text
# 暂时不支持上传图片,保留代码以便后续扩展
# elif ce.type == "image_base64":
# elif ce.type == "image_base64":
# image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
# file_bytes = base64.b64decode(image_b64)
# file = ("img.png", file_bytes, f"image/{image_format}")
@@ -92,147 +90,141 @@ class DashScopeAPIRunner(runner.RequestRunner):
plain_text = query.user_message.content
return plain_text, image_ids
async def _agent_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 智能体对话请求"""
#局部变量
chunk = None # 流式传输的块
pending_content = "" # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = "" # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
# 局部变量
chunk = None # 流式传输的块
pending_content = '' # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = '' # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
plain_text, image_ids = await self._preprocess_user_message(query)
#发送对话请求
# 发送对话请求
response = dashscope.Application.call(
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于多轮对话
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于多轮对话
# rag_options={ # 主要用于文件交互,暂不支持
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
# }
)
for chunk in response:
if chunk.get("status_code") != 200:
if chunk.get('status_code') != 200:
raise DashscopeAPIError(
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
)
if not chunk:
continue
#获取流式传输的output
stream_output = chunk.get("output", {})
if stream_output.get("text") is not None:
pending_content += stream_output.get("text")
#保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get("session_id")
#获取模型传出的参考资料列表
references_dict_list = stream_output.get("doc_references", [])
#从模型传出的参考资料信息中提取用于替换的字典
# 获取流式传输的output
stream_output = chunk.get('output', {})
if stream_output.get('text') is not None:
pending_content += stream_output.get('text')
# 保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get('session_id')
# 获取模型传出的参考资料列表
references_dict_list = stream_output.get('doc_references', [])
# 从模型传出的参考资料信息中提取用于替换的字典
if references_dict_list is not None:
for doc in references_dict_list:
if doc.get("index_id") is not None:
references_dict[doc.get("index_id")] = doc.get("doc_name")
#将参考资料替换到文本中
if doc.get('index_id') is not None:
references_dict[doc.get('index_id')] = doc.get('doc_name')
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
role="assistant",
role='assistant',
content=pending_content,
)
async def _workflow_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 工作流对话请求"""
#局部变量
chunk = None # 流式传输的块
pending_content = "" # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = "" # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
# 局部变量
chunk = None # 流式传输的块
pending_content = '' # 待处理的Agent输出内容
references_dict = {} # 用于存储引用编号和对应的参考资料
plain_text = '' # 用户输入的纯文本信息
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
plain_text, image_ids = await self._preprocess_user_message(query)
biz_params = {}
biz_params.update(query.variables)
#发送对话请求
# 发送对话请求
response = dashscope.Application.call(
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于多轮对话
biz_params=biz_params, # 工作流应用的自定义输入参数传递
api_key=self.api_key, # 智能体应用的API Key
app_id=self.app_id, # 智能体应用的ID
prompt=plain_text, # 用户输入的文本信息
stream=True, # 流式输出
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
session_id=query.session.using_conversation.uuid, # 会话ID用于多轮对话
biz_params=biz_params, # 工作流应用的自定义输入参数传递
# rag_options={ # 主要用于文件交互,暂不支持
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
# }
)
#处理API返回的流式输出
# 处理API返回的流式输出
for chunk in response:
if chunk.get("status_code") != 200:
if chunk.get('status_code') != 200:
raise DashscopeAPIError(
f"Dashscope API 请求失败: status_code={chunk.get('status_code')} message={chunk.get('message')} request_id={chunk.get('request_id')} "
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
)
if not chunk:
continue
#获取流式传输的output
stream_output = chunk.get("output", {})
if stream_output.get("text") is not None:
pending_content += stream_output.get("text")
#保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get("session_id")
#获取模型传出的参考资料列表
references_dict_list = stream_output.get("doc_references", [])
#从模型传出的参考资料信息中提取用于替换的字典
# 获取流式传输的output
stream_output = chunk.get('output', {})
if stream_output.get('text') is not None:
pending_content += stream_output.get('text')
# 保存当前会话的session_id用于下次对话的语境
query.session.using_conversation.uuid = stream_output.get('session_id')
# 获取模型传出的参考资料列表
references_dict_list = stream_output.get('doc_references', [])
# 从模型传出的参考资料信息中提取用于替换的字典
if references_dict_list is not None:
for doc in references_dict_list:
if doc.get("index_id") is not None:
references_dict[doc.get("index_id")] = doc.get("doc_name")
#将参考资料替换到文本中
if doc.get('index_id') is not None:
references_dict[doc.get('index_id')] = doc.get('doc_name')
# 将参考资料替换到文本中
pending_content = self._replace_references(pending_content, references_dict)
yield llm_entities.Message(
role="assistant",
role='assistant',
content=pending_content,
)
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行"""
if self.app_type == "agent":
if self.app_type == 'agent':
async for msg in self._agent_messages(query):
yield msg
elif self.app_type == "workflow":
elif self.app_type == 'workflow':
async for msg in self._workflow_messages(query):
yield msg
else:
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.app_type}"
)
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')

View File

@@ -5,9 +5,7 @@ import json
import uuid
import re
import base64
import datetime
import aiohttp
from .. import runner
from ...core import app, entities as core_entities
@@ -17,7 +15,7 @@ from ...utils import image
from libs.dify_service_api.v1 import client, errors
@runner.runner_class("dify-service-api")
@runner.runner_class('dify-service-api')
class DifyServiceAPIRunner(runner.RequestRunner):
"""Dify Service API 对话请求器"""
@@ -27,38 +25,54 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["chat", "agent", "workflow"]
valid_app_types = ['chat', 'agent', 'workflow']
if (
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
self.pipeline_config['ai']['dify-service-api']['app-type']
not in valid_app_types
):
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
api_key = self.pipeline_config['ai']['dify-service-api']['api-key']
self.dify_client = client.AsyncDifyServiceClient(
api_key=api_key,
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
base_url=self.pipeline_config['ai']['dify-service-api']['base-url'],
)
def _try_convert_thinking(self, resp_text: str) -> str:
"""尝试转换 Dify 的思考提示"""
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
if not resp_text.startswith(
'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>'
):
return resp_text
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'original'
):
return resp_text
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'remove'
):
return re.sub(
r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>',
'',
resp_text,
flags=re.DOTALL,
)
if (
self.pipeline_config['ai']['dify-service-api']['thinking-convert']
== 'plain'
):
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
return f"<think>{thinking_text.group(1)}</think>\n{content_text}"
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
async def _preprocess_user_message(
self, query: core_entities.Query
@@ -68,22 +82,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
Returns:
tuple[str, list[str]]: 纯文本和图片的 Dify 服务图片 ID
"""
plain_text = ""
plain_text = ''
image_ids = []
if isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == "text":
if ce.type == 'text':
plain_text += ce.text
elif ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
elif ce.type == 'image_base64':
image_b64, image_format = await image.extract_b64_and_format(
ce.image_base64
)
file_bytes = base64.b64decode(image_b64)
file = ("img.png", file_bytes, f"image/{image_format}")
file = ('img.png', file_bytes, f'image/{image_format}')
file_upload_resp = await self.dify_client.upload_file(
file,
f"{query.session.launcher_type.value}_{query.session.launcher_id}",
f'{query.session.launcher_type.value}_{query.session.launcher_id}',
)
image_id = file_upload_resp["id"]
image_id = file_upload_resp['id']
image_ids.append(image_id)
elif isinstance(query.user_message.content, str):
plain_text = query.user_message.content
@@ -94,116 +110,119 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""
cov_id = query.session.using_conversation.uuid or ''
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
mode = "basic" # 标记是基础编排还是工作流编排
mode = 'basic' # 标记是基础编排还是工作流编排
basic_mode_pending_chunk = ''
inputs = {}
inputs.update(query.variables)
async for chunk in self.dify_client.chat_messages(
inputs=inputs,
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
conversation_id=cov_id,
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
if chunk['event'] == 'workflow_started':
mode = "workflow"
mode = 'workflow'
if mode == "workflow":
if mode == 'workflow':
if chunk['event'] == 'node_finished':
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
role="assistant",
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
role='assistant',
content=self._try_convert_thinking(
chunk['data']['outputs']['answer']
),
)
elif mode == "basic":
elif mode == 'basic':
if chunk['event'] == 'message':
basic_mode_pending_chunk += chunk['answer']
elif chunk['event'] == 'message_end':
yield llm_entities.Message(
role="assistant",
role='assistant',
content=self._try_convert_thinking(basic_mode_pending_chunk),
)
basic_mode_pending_chunk = ''
query.session.using_conversation.uuid = chunk["conversation_id"]
query.session.using_conversation.uuid = chunk['conversation_id']
async def _agent_chat_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""
cov_id = query.session.using_conversation.uuid or ''
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
ignored_events = ["agent_message"]
ignored_events = ['agent_message']
inputs = {}
inputs.update(query.variables)
async for chunk in self.dify_client.chat_messages(
inputs=inputs,
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
response_mode="streaming",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
response_mode='streaming',
conversation_id=cov_id,
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
if chunk["event"] in ignored_events:
if chunk['event'] in ignored_events:
continue
if chunk["event"] == "agent_thought":
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
if chunk['event'] == 'agent_thought':
if (
chunk['tool'] != '' and chunk['observation'] != ''
): # 工具调用结果,跳过
continue
if chunk['thought'].strip() != '': # 文字回复内容
msg = llm_entities.Message(
role="assistant",
content=chunk["thought"],
role='assistant',
content=chunk['thought'],
)
yield msg
if chunk['tool']:
msg = llm_entities.Message(
role="assistant",
role='assistant',
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type="function",
type='function',
function=llm_entities.FunctionCall(
name=chunk["tool"],
name=chunk['tool'],
arguments=json.dumps({}),
),
)
@@ -211,9 +230,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
)
yield msg
if chunk['event'] == 'message_file':
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
base_url = self.dify_client.base_url
if base_url.endswith('/v1'):
@@ -222,11 +239,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
image_url = base_url + chunk['url']
yield llm_entities.Message(
role="assistant",
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
query.session.using_conversation.uuid = chunk["conversation_id"]
query.session.using_conversation.uuid = chunk['conversation_id']
async def _workflow_messages(
self, query: core_entities.Query
@@ -235,58 +252,57 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if not query.session.using_conversation.uuid:
query.session.using_conversation.uuid = str(uuid.uuid4())
query.variables["conversation_id"] = query.session.using_conversation.uuid
query.variables['conversation_id'] = query.session.using_conversation.uuid
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
'type': 'image',
'transfer_method': 'local_file',
'upload_file_id': image_id,
}
for image_id in image_ids
]
ignored_events = ["text_chunk", "workflow_started"]
ignored_events = ['text_chunk', 'workflow_started']
inputs = { # these variables are legacy variables, we need to keep them for compatibility
"langbot_user_message_text": plain_text,
"langbot_session_id": query.variables["session_id"],
"langbot_conversation_id": query.variables["conversation_id"],
"langbot_msg_create_time": query.variables["msg_create_time"],
'langbot_user_message_text': plain_text,
'langbot_session_id': query.variables['session_id'],
'langbot_conversation_id': query.variables['conversation_id'],
'langbot_msg_create_time': query.variables['msg_create_time'],
}
inputs.update(query.variables)
async for chunk in self.dify_client.workflow_run(
inputs=inputs,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
files=files,
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
timeout=self.pipeline_config['ai']['dify-service-api']['timeout'],
):
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
if chunk["event"] in ignored_events:
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
if chunk['event'] in ignored_events:
continue
if chunk["event"] == "node_started":
if chunk['event'] == 'node_started':
if (
chunk["data"]["node_type"] == "start"
or chunk["data"]["node_type"] == "end"
chunk['data']['node_type'] == 'start'
or chunk['data']['node_type'] == 'end'
):
continue
msg = llm_entities.Message(
role="assistant",
role='assistant',
content=None,
tool_calls=[
llm_entities.ToolCall(
id=chunk["data"]["node_id"],
type="function",
id=chunk['data']['node_id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk["data"]["title"],
name=chunk['data']['title'],
arguments=json.dumps({}),
),
)
@@ -295,13 +311,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg
elif chunk["event"] == "workflow_finished":
elif chunk['event'] == 'workflow_finished':
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
msg = llm_entities.Message(
role="assistant",
content=chunk["data"]["outputs"]["summary"],
role='assistant',
content=chunk['data']['outputs']['summary'],
)
yield msg
@@ -310,16 +326,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query):
yield msg
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
async for msg in self._agent_chat_messages(query):
yield msg
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
async for msg in self._workflow_messages(query):
yield msg
else:
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)

View File

@@ -4,24 +4,28 @@ import json
import typing
from .. import runner
from ...core import app, entities as core_entities
from ...core import entities as core_entities
from .. import entities as llm_entities
@runner.runner_class("local-agent")
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器
"""
"""本地Agent请求运行器"""
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
req_messages = (
query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
)
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(
query, query.use_llm_model, req_messages, query.use_funcs
)
yield msg
@@ -34,7 +38,7 @@ class LocalAgentRunner(runner.RequestRunner):
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
@@ -42,7 +46,9 @@ class LocalAgentRunner(runner.RequestRunner):
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
role='tool',
content=json.dumps(func_ret, ensure_ascii=False),
tool_call_id=tool_call.id,
)
yield msg
@@ -51,7 +57,7 @@ class LocalAgentRunner(runner.RequestRunner):
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
role='tool', content=f'err: {e}', tool_call_id=tool_call.id
)
yield err_msg
@@ -59,7 +65,9 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(
query, query.use_llm_model, req_messages, query.use_funcs
)
yield msg

View File

@@ -3,13 +3,11 @@ from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
from ...provider import entities as provider_entities
class SessionManager:
"""会话管理器
"""
"""会话管理器"""
ap: app.Application
@@ -23,10 +21,12 @@ class SessionManager:
pass
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
"""获取会话
"""
"""获取会话"""
for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
if (
query.launcher_type == session.launcher_type
and query.launcher_id == session.launcher_id
):
return session
session_concurrency = self.ap.instance_config.data['concurrency']['session']
@@ -39,7 +39,12 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation:
async def get_conversation(
self,
query: core_entities.Query,
session: core_entities.Session,
prompt_config: list[dict],
) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
@@ -52,7 +57,7 @@ class SessionManager:
prompt_messages.append(provider_entities.Message(**prompt_message))
prompt = provider_entities.Prompt(
name="default",
name='default',
messages=prompt_messages,
)

View File

@@ -1,13 +1,9 @@
from __future__ import annotations
import abc
import typing
import asyncio
import pydantic.v1 as pydantic
from ...core import entities as core_entities
class LLMFunction(pydantic.BaseModel):
"""函数"""

View File

@@ -9,9 +9,10 @@ from . import entities as tools_entities
preregistered_loaders: list[typing.Type[ToolLoader]] = []
def loader_class(name: str):
"""注册一个工具加载器
"""
"""注册一个工具加载器"""
def decorator(cls: typing.Type[ToolLoader]) -> typing.Type[ToolLoader]:
cls.name = name
preregistered_loaders.append(cls)
@@ -22,7 +23,7 @@ def loader_class(name: str):
class ToolLoader(abc.ABC):
"""工具加载器"""
name: str = None
ap: app.Application
@@ -34,7 +35,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
"""获取所有工具"""
pass
@@ -44,11 +45,13 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
"""执行工具调用"""
pass
@abc.abstractmethod
async def shutdown(self):
"""关闭工具"""
pass
pass

View File

@@ -30,7 +30,7 @@ class RuntimeMCPSession:
self.server_name = server_name
self.server_config = server_config
self.ap = ap
self.session = None
self.exit_stack = AsyncExitStack()
@@ -38,9 +38,9 @@ class RuntimeMCPSession:
async def _init_stdio_python_server(self):
server_params = StdioServerParameters(
command=self.server_config["command"],
args=self.server_config["args"],
env=self.server_config["env"],
command=self.server_config['command'],
args=self.server_config['args'],
env=self.server_config['env'],
)
stdio_transport = await self.exit_stack.enter_async_context(
@@ -58,12 +58,12 @@ class RuntimeMCPSession:
async def _init_sse_server(self):
sse_transport = await self.exit_stack.enter_async_context(
sse_client(
self.server_config["url"],
headers=self.server_config.get("headers", {}),
timeout=self.server_config.get("timeout", 10),
self.server_config['url'],
headers=self.server_config.get('headers', {}),
timeout=self.server_config.get('timeout', 10),
)
)
sseio, write = sse_transport
self.session = await self.exit_stack.enter_async_context(
@@ -73,18 +73,22 @@ class RuntimeMCPSession:
await self.session.initialize()
async def initialize(self):
self.ap.logger.debug(f"初始化 MCP 会话: {self.server_name} {self.server_config}")
self.ap.logger.debug(
f'初始化 MCP 会话: {self.server_name} {self.server_config}'
)
if self.server_config["mode"] == "stdio":
if self.server_config['mode'] == 'stdio':
await self._init_stdio_python_server()
elif self.server_config["mode"] == "sse":
elif self.server_config['mode'] == 'sse':
await self._init_sse_server()
else:
raise ValueError(f"无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}")
raise ValueError(
f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}'
)
tools = await self.session.list_tools()
self.ap.logger.debug(f"获取 MCP 工具: {tools}")
self.ap.logger.debug(f'获取 MCP 工具: {tools}')
for tool in tools.tools:
@@ -93,25 +97,28 @@ class RuntimeMCPSession:
if result.isError:
raise Exception(result.content[0].text)
return result.content[0].text
func.__name__ = tool.name
self.functions.append(tools_entities.LLMFunction(
name=tool.name,
human_desc=tool.description,
description=tool.description,
parameters=tool.inputSchema,
func=func,
))
self.functions.append(
tools_entities.LLMFunction(
name=tool.name,
human_desc=tool.description,
description=tool.description,
parameters=tool.inputSchema,
func=func,
)
)
async def shutdown(self):
"""关闭工具"""
await self.session._exit_stack.aclose()
@loader.loader_class("mcp")
@loader.loader_class('mcp')
class MCPLoader(loader.ToolLoader):
"""MCP 工具加载器。
在此加载器中管理所有与 MCP Server 的连接。
"""
@@ -125,16 +132,17 @@ class MCPLoader(loader.ToolLoader):
self._last_listed_functions = []
async def initialize(self):
for server_config in self.ap.instance_config.data.get("mcp", {}).get("servers", []):
if not server_config["enable"]:
for server_config in self.ap.instance_config.data.get('mcp', {}).get(
'servers', []
):
if not server_config['enable']:
continue
session = RuntimeMCPSession(server_config["name"], server_config, self.ap)
session = RuntimeMCPSession(server_config['name'], server_config, self.ap)
await session.initialize()
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config["name"]] = session
self.sessions[server_config['name']] = session
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
all_functions = []
for session in self.sessions.values():
@@ -147,13 +155,15 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
for server_name, session in self.sessions.items():
for function in session.functions:
if function.name == name:
return await function.func(query, **parameters)
raise ValueError(f"未找到工具: {name}")
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
"""关闭工具"""

View File

@@ -4,19 +4,18 @@ import typing
import traceback
from .. import loader, entities as tools_entities
from ....core import app, entities as core_entities
from ....core import entities as core_entities
from ....plugin import context as plugin_context
@loader.loader_class("plugin-tool-loader")
@loader.loader_class('plugin-tool-loader')
class PluginToolLoader(loader.ToolLoader):
"""插件工具加载器。
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
"""
async def get_tools(self, enabled: bool=True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
# 从插件系统获取工具(内容函数)
all_functions: list[tools_entities.LLMFunction] = []
@@ -49,23 +48,23 @@ class PluginToolLoader(loader.ToolLoader):
return function, plugin.plugin_inst
return None, None
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
async def invoke_tool(
self, query: core_entities.Query, name: str, parameters: dict
) -> typing.Any:
try:
function, plugin = await self._get_function_and_plugin(name)
if function is None:
return None
parameters = parameters.copy()
parameters = {"query": query, **parameters}
parameters = {'query': query, **parameters}
return await function.func(plugin, **parameters)
except Exception as e:
self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()
return f"error occurred when executing function {name}: {e}"
return f'error occurred when executing function {name}: {e}'
finally:
plugin = None
@@ -75,13 +74,12 @@ class PluginToolLoader(loader.ToolLoader):
break
if plugin is not None:
await self.ap.ctr_mgr.usage.post_function_record(
plugin={
"name": plugin.plugin_name,
"remote": plugin.plugin_repository,
"version": plugin.plugin_version,
"author": plugin.plugin_author,
'name': plugin.plugin_name,
'remote': plugin.plugin_repository,
'version': plugin.plugin_version,
'author': plugin.plugin_author,
},
function_name=function.name,
function_description=function.description,

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import typing
import traceback
from ...core import app, entities as core_entities
from . import entities, loader as tools_loader
from ...plugin import context as plugin_context
from .loaders import plugin, mcp
from ...utils import importutil
from . import loaders
importutil.import_modules_in_pkg(loaders)
class ToolManager:
@@ -22,13 +23,14 @@ class ToolManager:
self.loaders = []
async def initialize(self):
for loader_cls in tools_loader.preregistered_loaders:
loader_inst = loader_cls(self.ap)
await loader_inst.initialize()
self.loaders.append(loader_inst)
async def get_all_functions(self, plugin_enabled: bool=None) -> list[entities.LLMFunction]:
async def get_all_functions(
self, plugin_enabled: bool = None
) -> list[entities.LLMFunction]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
@@ -37,17 +39,19 @@ class ToolManager:
return all_functions
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
async def generate_tools_for_openai(
self, use_funcs: list[entities.LLMFunction]
) -> list:
"""生成函数列表"""
tools = []
for function in use_funcs:
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters,
'type': 'function',
'function': {
'name': function.name,
'description': function.description,
'parameters': function.parameters,
},
}
tools.append(function_schema)
@@ -83,9 +87,9 @@ class ToolManager:
for function in use_funcs:
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
'name': function.name,
'description': function.description,
'input_schema': function.parameters,
}
tools.append(function_schema)
@@ -100,7 +104,7 @@ class ToolManager:
if await loader.has_tool(name):
return await loader.invoke_tool(query, name, parameters)
else:
raise ValueError(f"未找到工具: {name}")
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
"""关闭所有工具"""