From 47ff883fc73c652b5604048729e83ae0a8ce3b59 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 3 Aug 2025 13:08:51 +0800 Subject: [PATCH] perf: ruff format & remove `stream` params in requester --- libs/dingtalk_api/api.py | 30 ++--- .../controller/groups/pipelines/webchat.py | 20 ++-- pkg/core/entities.py | 4 +- pkg/pipeline/cntfilter/cntfilter.py | 2 +- pkg/pipeline/process/handlers/chat.py | 4 +- pkg/pipeline/respback/respback.py | 4 - pkg/platform/adapter.py | 17 ++- pkg/platform/sources/dingtalk.py | 8 +- pkg/platform/sources/discord.py | 1 - pkg/platform/sources/qqbotpy.py | 2 +- pkg/platform/sources/telegram.py | 1 - pkg/platform/sources/webchat.py | 17 ++- pkg/platform/sources/wechatpad.py | 13 +- pkg/platform/types/message.py | 17 +-- pkg/provider/entities.py | 4 +- pkg/provider/modelmgr/requester.py | 17 ++- pkg/provider/modelmgr/requesters/chatcmpl.py | 78 ++++++------ .../modelmgr/requesters/giteeaichatcmpl.py | 99 +++++++--------- .../modelmgr/requesters/modelscopechatcmpl.py | 96 +++++++-------- .../modelmgr/requesters/ppiochatcmpl.py | 111 +++++++++--------- pkg/provider/runners/difysvapi.py | 4 +- pkg/provider/runners/localagent.py | 3 +- pkg/utils/image.py | 8 +- pkg/utils/importutil.py | 2 +- 24 files changed, 263 insertions(+), 299 deletions(-) diff --git a/libs/dingtalk_api/api.py b/libs/dingtalk_api/api.py index d1c7065f..3d483a3a 100644 --- a/libs/dingtalk_api/api.py +++ b/libs/dingtalk_api/api.py @@ -3,7 +3,6 @@ import json import time from typing import Callable import dingtalk_stream # type: ignore -from dingtalk_stream import AckMessage, ChatbotHandler, CallbackHandler, CallbackMessage, ChatbotMessage, AICardReplier from .EchoHandler import EchoTextHandler from .dingtalkevent import DingTalkEvent import httpx @@ -254,24 +253,23 @@ class DingTalkClient: await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}') raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}') - async def create_and_card(self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage,quote_origin:bool=False): - content_key = "content" - card_data = {content_key: ""} + async def create_and_card( + self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False + ): + content_key = 'content' + card_data = {content_key: ''} - card_instance = dingtalk_stream.AICardReplier( - self.client, incoming_message - ) + card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message) # print(card_instance) # 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards card_instance_id = await card_instance.async_create_and_deliver_card( - temp_card_id, card_data, + temp_card_id, + card_data, ) - return card_instance,card_instance_id + return card_instance, card_instance_id - async def send_card_message(self, - card_instance, - card_instance_id: str,content: str,is_final: bool): - content_key = "content" + async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool): + content_key = 'content' try: await card_instance.async_streaming( card_instance_id, @@ -286,16 +284,12 @@ class DingTalkClient: await card_instance.async_streaming( card_instance_id, content_key=content_key, - content_value="", + content_value='', append=False, finished=is_final, failed=True, ) - - - - async def start(self): """启动 WebSocket 连接,监听消息""" await self.client.start() diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index 6dc7f85a..c094731b 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -14,8 +14,9 @@ class WebChatDebugRouterGroup(group.RouterGroup): async def stream_generator(generator): async for message in generator: - yield f"data: {json.dumps({'message': message})}\n\n" - yield "data: {\"type\": \"end\"}\n\n" + yield f'data: {json.dumps({"message": message})}\n\n' + yield 'data: {"type": "end"}\n\n' + try: data = await quart.request.get_json() session_type = data.get('session_type', 'person') @@ -34,18 +35,18 @@ class WebChatDebugRouterGroup(group.RouterGroup): return self.http_status(404, -1, 'WebChat adapter not found') if is_stream: - - generator = webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj, is_stream) - - return quart.Response( - stream_generator(generator), - mimetype='text/event-stream' + generator = webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj, is_stream ) + return quart.Response(stream_generator(generator), mimetype='text/event-stream') + else: # result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) result = None - async for message in webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj): + async for message in webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj + ): result = message if result is not None: return self.success( @@ -56,7 +57,6 @@ class WebChatDebugRouterGroup(group.RouterGroup): else: return self.http_status(400, -1, 'message is required') - except Exception as e: return self.http_status(500, -1, f'Internal server error: {str(e)}') diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 31514fa8..5f357d78 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -87,7 +87,9 @@ class Query(pydantic.BaseModel): """使用的函数,由前置处理器阶段设置""" resp_messages: ( - typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] | typing.Optional[list[llm_entities.MessageChunk]] + typing.Optional[list[llm_entities.Message]] + | typing.Optional[list[platform_message.MessageChain]] + | typing.Optional[list[llm_entities.MessageChunk]] ) = [] """由Process阶段生成的回复消息对象列表""" diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 0bbc5103..e035c1d0 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -67,7 +67,7 @@ class ContentFilterStage(stage.PipelineStage): if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) if not message.strip(): - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 0f802658..6c428473 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -81,9 +81,7 @@ class ChatMessageHandler(handler.MessageHandler): query.resp_message_chain.pop() query.resp_messages.append(result) - self.ap.logger.info( - f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}' - ) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') if result.content is not None: text_length += len(result.content) diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index f4153218..c7824856 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -3,12 +3,10 @@ from __future__ import annotations import random import asyncio -from typing_inspection.typing_objects import is_final from ...platform.types import events as platform_events from ...platform.types import message as platform_message -from ...provider import entities as llm_entities from .. import stage, entities from ...core import entities as core_entities @@ -56,6 +54,4 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin=quote_origin, ) - - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index e4369efb..3412be3c 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -25,7 +25,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): logger: EventLogger - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): """初始化适配器 @@ -80,12 +79,12 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): """ raise NotImplementedError - async def create_message_card(self, message_id:typing.Type[str,int], event:platform_events.MessageEvent) -> bool: + async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool: """创建卡片消息 Args: message_id (str): 消息ID event (platform_events.MessageEvent): 消息源事件 - """ + """ return False async def is_muted(self, group_id: int) -> bool: @@ -94,8 +93,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def register_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注册事件监听器 @@ -107,8 +106,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def unregister_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注销事件监听器 @@ -167,7 +166,7 @@ class EventConverter: """事件转换器基类""" @staticmethod - def yiri2target(event: typing.Type[platform_message.Event]): + def yiri2target(event: typing.Type[platform_events.Event]): """将源平台事件转换为目标平台事件 Args: @@ -179,7 +178,7 @@ class EventConverter: raise NotImplementedError @staticmethod - def target2yiri(event: typing.Any) -> platform_message.Event: + def target2yiri(event: typing.Any) -> platform_events.Event: """将目标平台事件的调用参数转换为源平台的事件参数对象 Args: diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index 9f834f2a..8bd6e187 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -149,10 +149,10 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): quote_origin: bool = False, is_final: bool = False, ): - event = await DingTalkEventConverter.yiri2target( - message_source, - ) - incoming_message = event.incoming_message + # event = await DingTalkEventConverter.yiri2target( + # message_source, + # ) + # incoming_message = event.incoming_message # msg_id = incoming_message.message_id diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index c279e714..da32c7ac 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,7 +8,6 @@ import base64 import uuid import os import datetime -import io import asyncio from enum import Enum diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 39c8dc8a..d4a4d526 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -501,7 +501,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): for event_handler in event_handler_mapping[event_type]: setattr(self.bot, event_handler, wrapper) except Exception as e: - self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}") + self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 22ef63e8..d39bf23d 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -1,6 +1,5 @@ from __future__ import annotations -import time import telegram import telegram.ext diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index f7f3d964..fce28bc2 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -133,7 +133,11 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): ) # notify waiter - session = (self.webchat_group_session if isinstance(message_source, platform_events.GroupMessage) else self.webchat_person_session) + session = ( + self.webchat_group_session + if isinstance(message_source, platform_events.GroupMessage) + else self.webchat_person_session + ) if message_source.message_chain.message_id not in session.resp_waiters: # session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue() queue = session.resp_queues[message_source.message_chain.message_id] @@ -147,10 +151,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): # print(message_data) await queue.put(message_data) - - return message_data.model_dump() - + async def is_stream_output_supported(self) -> bool: return self.is_stream @@ -186,7 +188,10 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): await self.logger.info('WebChat调试适配器正在停止') async def send_webchat_message( - self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict], + self, + pipeline_uuid: str, + session_type: str, + message_chain_obj: typing.List[dict], is_stream: bool = False, ) -> dict: self.is_stream = is_stream @@ -202,7 +207,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): if is_stream: use_session.resp_queues[message_id] = asyncio.Queue() - logger.debug(f"Initialized queue for message_id: {message_id}") + logger.debug(f'Initialized queue for message_id: {message_id}') use_session.get_message_list(pipeline_uuid).append( WebChatMessage( diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index 9bbb471d..895e77fb 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -241,8 +241,8 @@ class WeChatPadMessageConverter(adapter.MessageConverter): # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') quote_data = '' # 引用原文 - quote_id = None # 引用消息的原发送者 - tousername = None # 接收方: 所属微信的wxid + # quote_id = None # 引用消息的原发送者 + # tousername = None # 接收方: 所属微信的wxid user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member @@ -250,13 +250,10 @@ class WeChatPadMessageConverter(adapter.MessageConverter): if appmsg_data: user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') - quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') + # quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) - if message: - tousername = message['to_user_name']['str'] - - _ = quote_id - _ = tousername + # if message: + # tousername = message['to_user_name']['str'] if quote_data: quote_data_message_list = platform_message.MessageChain() diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py index 7dad4145..ecd7cc96 100644 --- a/pkg/platform/types/message.py +++ b/pkg/platform/types/message.py @@ -812,12 +812,14 @@ class File(MessageComponent): def __str__(self): return f'[文件]{self.name}' + class Face(MessageComponent): """系统表情 此处将超级表情骰子/划拳,一同归类于face 当face_type为rps(划拳)时 face_id 对应的是手势 当face_type为dice(骰子)时 face_id 对应的是点数 """ + type: str = 'Face' """表情类型""" face_type: str = 'face' @@ -834,15 +836,15 @@ class Face(MessageComponent): elif self.face_type == 'rps': return f'[表情]{self.face_name}({self.rps_data(self.face_id)})' - - def rps_data(self,face_id): - rps_dict ={ - 1 : "布", - 2 : "剪刀", - 3 : "石头", + def rps_data(self, face_id): + rps_dict = { + 1: '布', + 2: '剪刀', + 3: '石头', } return rps_dict[face_id] + # ================ 个人微信专用组件 ================ @@ -971,5 +973,6 @@ class WeChatFile(MessageComponent): """文件地址""" file_base64: str = '' """base64""" + def __str__(self): - return f'[文件]{self.file_name}' \ No newline at end of file + return f'[文件]{self.file_name}' diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index df2b5487..ff1e4526 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -127,6 +127,7 @@ class Message(pydantic.BaseModel): class MessageChunk(pydantic.BaseModel): """消息""" + resp_message_id: typing.Optional[str] = None """消息id""" @@ -148,7 +149,7 @@ class MessageChunk(pydantic.BaseModel): tool_call_id: typing.Optional[str] = None # tool_calls: typing.Optional[list[ToolCallChunk]] = None - + is_final: bool = False def readable_str(self) -> str: @@ -210,6 +211,7 @@ class ToolCallChunk(pydantic.BaseModel): function: FunctionCall """函数调用""" + class Prompt(pydantic.BaseModel): """供AI使用的Prompt""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index fa4a9ff8..d28783b9 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -94,19 +94,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + llm_entities.Message: 返回消息对象 """ pass @abc.abstractmethod async def invoke_llm_stream( - self, - query: core_entities.Query, - model: RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, - extra_args: dict[str, typing.Any] = {}, + self, + query: core_entities.Query, + model: RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.MessageChunk: """调用API @@ -117,7 +116,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 """ pass diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index d5c3b90a..4fcce481 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -8,7 +8,7 @@ import openai.types.chat.chat_completion as chat_completion import httpx from .. import errors, requester -from ....core import entities as core_entities, app +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities @@ -129,12 +129,10 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) ->llm_entities.MessageChunk: + ) -> llm_entities.MessageChunk: self.client.api_key = use_model.token_mgr.get_token() - args = {} args['model'] = use_model.model_entity.name @@ -158,43 +156,42 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages - if stream: - current_content = '' - args['stream'] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content - if chunk_idx % 64 == 0 or delta_message.is_final: - yield delta_message - # return + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message + # return async def _closure( self, @@ -202,7 +199,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -317,7 +313,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 @@ -337,7 +332,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages=req_messages, use_model=model, use_funcs=funcs, - stream=stream, extra_args=extra_args, ): yield item diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 2a618c9f..1c19a534 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -12,7 +12,6 @@ import re import openai.types.chat.chat_completion as chat_completion - class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): """Gitee AI ChatCompletions API 请求器""" @@ -20,7 +19,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): 'base_url': 'https://ai.gitee.com/v1', 'timeout': 120, } - is_think:bool = False + is_think: bool = False async def _closure( self, @@ -52,15 +51,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): pipeline_config = query.pipeline_config - message = await self._make_msg(resp,pipeline_config) + message = await self._make_msg(resp, pipeline_config) return message - async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, + self, + chat_completion: chat_completion.ChatCompletion, + pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, ) -> llm_entities.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # print(chatcmpl_message.keys(), chatcmpl_message.values()) @@ -73,23 +71,25 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - chatcmpl_message['content'] = re.sub(r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL) + chatcmpl_message['content'] = re.sub( + r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL + ) else: if reasoning_content is not None: - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + chatcmpl_message['content'] = ( + '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + ) message = llm_entities.Message(**chatcmpl_message) return message - async def _make_msg_chunk( self, pipeline_config: dict[str, typing.Any], chat_completion: chat_completion.ChatCompletion, idx: int, ) -> llm_entities.MessageChunk: - # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) if hasattr(chat_completion, 'choices'): @@ -104,7 +104,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None delta['content'] = '' if delta['content'] is None else delta['content'] @@ -115,7 +114,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if delta['content'] == '': self.is_think = True delta['content'] = '' - if delta['content'] == rf'': + if delta['content'] == r'': self.is_think = False delta['content'] = '' if not self.is_think: @@ -126,7 +125,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if reasoning_content is not None: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) return message @@ -137,7 +135,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -165,44 +162,38 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content - - if chunk_idx % 64 == 0 or delta_message.is_final: - - yield delta_message + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index c1888a5e..97201e47 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -165,11 +165,10 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): return message async def _req_stream( - self, - args: dict, - extra_body: dict = {}, + self, + args: dict, + extra_body: dict = {}, ) -> chat_completion.ChatCompletion: - async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): yield chunk @@ -179,7 +178,6 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): chat_completion: chat_completion.ChatCompletion, idx: int, ) -> llm_entities.MessageChunk: - # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) if hasattr(chat_completion, 'choices'): @@ -195,7 +193,6 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None delta['content'] = '' if delta['content'] is None else delta['content'] @@ -203,13 +200,13 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - if reasoning_content is not None : + if reasoning_content is not None: pass else: delta['content'] = delta['content'] else: if reasoning_content is not None and idx == 0: - delta['content'] += f'\n{reasoning_content}' + delta['content'] += f'\n{reasoning_content}' elif reasoning_content is None: if self.is_content: delta['content'] = delta['content'] @@ -219,7 +216,6 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): else: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) return message @@ -230,7 +226,6 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -258,48 +253,42 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content - - if chunk_idx % 64 == 0 or delta_message.is_final: - - yield delta_message - # return + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message + # return async def invoke_llm( self, @@ -340,16 +329,14 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') - async def invoke_llm_stream( self, query: core_entities.Query, model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.MessageChunk: + ) -> llm_entities.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -367,7 +354,6 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): req_messages=req_messages, use_model=model, use_funcs=funcs, - stream=stream, extra_args=extra_args, ): yield item @@ -386,4 +372,4 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): except openai.RateLimitError as e: raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') \ No newline at end of file + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index 85b321a7..46da6e01 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -5,8 +5,8 @@ import typing from . import chatcmpl import openai.types.chat.chat_completion as chat_completion -from .. import errors, requester -from ....core import entities as core_entities, app +from .. import requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities import re @@ -25,9 +25,9 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): is_think: bool = False async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, + self, + chat_completion: chat_completion.ChatCompletion, + pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, ) -> llm_entities.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # print(chatcmpl_message.keys(), chatcmpl_message.values()) @@ -40,21 +40,24 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - chatcmpl_message['content'] = re.sub(r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL) + chatcmpl_message['content'] = re.sub( + r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL + ) else: if reasoning_content is not None: - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + chatcmpl_message['content'] = ( + '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + ) message = llm_entities.Message(**chatcmpl_message) return message - async def _make_msg_chunk( - self, - pipeline_config: dict[str, typing.Any], - chat_completion: chat_completion.ChatCompletion, - idx: int, + self, + pipeline_config: dict[str, typing.Any], + chat_completion: chat_completion.ChatCompletion, + idx: int, ) -> llm_entities.MessageChunk: # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) @@ -80,7 +83,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): if '' in delta['content']: self.is_think = True delta['content'] = '' - if rf'' in delta['content']: + if r'' in delta['content']: self.is_think = False delta['content'] = '' if not self.is_think: @@ -95,15 +98,13 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): return message - async def _closure_stream( - self, - query: core_entities.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, - extra_args: dict[str, typing.Any] = {}, + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -130,40 +131,38 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content - if chunk_idx % 64 == 0 or delta_message.is_final: - yield delta_message + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 8182cc54..40a3140c 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -348,7 +348,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): except AttributeError: is_stream = False - batch_pending_index = 0 + _ = is_stream + + # batch_pending_index = 0 plain_text, image_ids = await self._preprocess_user_message(query) diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 599b0b08..3ff0ce9d 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -128,8 +128,7 @@ class LocalAgentRunner(runner.RequestRunner): id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' + name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: diff --git a/pkg/utils/image.py b/pkg/utils/image.py index f69d29d2..d9518e12 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -204,9 +204,9 @@ async def get_slack_image_to_base64(pic_url: str, bot_token: str): try: async with aiohttp.ClientSession() as session: async with session.get(pic_url, headers=headers) as resp: - mime_type = resp.headers.get("Content-Type", "application/octet-stream") + mime_type = resp.headers.get('Content-Type', 'application/octet-stream') file_bytes = await resp.read() - base64_str = base64.b64encode(file_bytes).decode("utf-8") - return f"data:{mime_type};base64,{base64_str}" + base64_str = base64.b64encode(file_bytes).decode('utf-8') + return f'data:{mime_type};base64,{base64_str}' except Exception as e: - raise (e) \ No newline at end of file + raise (e) diff --git a/pkg/utils/importutil.py b/pkg/utils/importutil.py index 8acc5c45..1933d611 100644 --- a/pkg/utils/importutil.py +++ b/pkg/utils/importutil.py @@ -32,7 +32,7 @@ def import_dir(path: str): rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '') rel_path = rel_path[1:] rel_path = rel_path.replace('/', '.')[:-3] - rel_path = rel_path.replace("\\",".") + rel_path = rel_path.replace('\\', '.') importlib.import_module(rel_path)