diff --git a/libs/dingtalk_api/api.py b/libs/dingtalk_api/api.py index d323df1e..3d483a3a 100644 --- a/libs/dingtalk_api/api.py +++ b/libs/dingtalk_api/api.py @@ -253,6 +253,43 @@ class DingTalkClient: await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}') raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}') + async def create_and_card( + self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False + ): + content_key = 'content' + card_data = {content_key: ''} + + card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message) + # print(card_instance) + # 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards + card_instance_id = await card_instance.async_create_and_deliver_card( + temp_card_id, + card_data, + ) + return card_instance, card_instance_id + + async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool): + content_key = 'content' + try: + await card_instance.async_streaming( + card_instance_id, + content_key=content_key, + content_value=content, + append=False, + finished=is_final, + failed=False, + ) + except Exception as e: + self.logger.exception(e) + await card_instance.async_streaming( + card_instance_id, + content_key=content_key, + content_value='', + append=False, + finished=is_final, + failed=True, + ) + async def start(self): """启动 WebSocket 连接,监听消息""" await self.client.start() diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index c8c8db54..7eea471a 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -1,3 +1,5 @@ +import json + import quart from ... import group @@ -9,10 +11,18 @@ class WebChatDebugRouterGroup(group.RouterGroup): @self.route('/send', methods=['POST']) async def send_message(pipeline_uuid: str) -> str: """Send a message to the pipeline for debugging""" + + async def stream_generator(generator): + yield 'data: {"type": "start"}\n\n' + async for message in generator: + yield f'data: {json.dumps({"message": message})}\n\n' + yield 'data: {"type": "end"}\n\n' + try: data = await quart.request.get_json() session_type = data.get('session_type', 'person') message_chain_obj = data.get('message', []) + is_stream = data.get('is_stream', False) if not message_chain_obj: return self.http_status(400, -1, 'message is required') @@ -25,13 +35,33 @@ class WebChatDebugRouterGroup(group.RouterGroup): if not webchat_adapter: return self.http_status(404, -1, 'WebChat adapter not found') - result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) - - return self.success( - data={ - 'message': result, + if is_stream: + generator = webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj, is_stream + ) + # 设置正确的响应头 + headers = { + 'Content-Type': 'text/event-stream', + 'Transfer-Encoding': 'chunked', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive' } - ) + return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers) + + else: # non-stream + result = None + async for message in webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj + ): + result = message + if result is not None: + return self.success( + data={ + 'message': result, + } + ) + else: + return self.http_status(400, -1, 'message is required') except Exception as e: return self.http_status(500, -1, f'Internal server error: {str(e)}') diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index d8457da3..d3f3d5d8 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -101,7 +101,7 @@ class LLMModelsService: model=runtime_llm_model, messages=[llm_entities.Message(role='user', content='Hello, world!')], funcs=[], - extra_args={}, + extra_args=model_data.get('extra_args', {}), ) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4caf18ed..1efee3fc 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -87,7 +87,9 @@ class Query(pydantic.BaseModel): """使用的函数,由前置处理器阶段设置""" resp_messages: ( - typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] + typing.Optional[list[llm_entities.Message]] + | typing.Optional[list[platform_message.MessageChain]] + | typing.Optional[list[llm_entities.MessageChunk]] ) = [] """由Process阶段生成的回复消息对象列表""" diff --git a/pkg/persistence/migrations/dbm005_pipeline_remove_cot_config.py b/pkg/persistence/migrations/dbm005_pipeline_remove_cot_config.py new file mode 100644 index 00000000..14f0beec --- /dev/null +++ b/pkg/persistence/migrations/dbm005_pipeline_remove_cot_config.py @@ -0,0 +1,38 @@ +from .. import migration + +import sqlalchemy + +from ...entity.persistence import pipeline as persistence_pipeline + + +@migration.migration_class(5) +class DBMigratePipelineRemoveCotConfig(migration.DBMigration): + """Pipeline remove cot config""" + + async def upgrade(self): + """Upgrade""" + # read all pipelines + pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) + + for pipeline in pipelines: + serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + + config = serialized_pipeline['config'] + + if 'remove-think' not in config['output']['misc']: + config['output']['misc']['remove-think'] = True + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_pipeline.LegacyPipeline) + .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid']) + .values( + { + 'config': config, + 'for_version': self.ap.ver_mgr.get_current_version(), + } + ) + ) + + async def downgrade(self): + """Downgrade""" + pass diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index 77df09dc..abf80e16 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -93,12 +93,20 @@ class RuntimePipeline: query.message_event, platform_events.GroupMessage ): result.user_notice.insert(0, platform_message.At(query.message_event.sender.id)) - - await query.adapter.reply_message( - message_source=query.message_event, - message=result.user_notice, - quote_origin=query.pipeline_config['output']['misc']['quote-origin'], - ) + if await query.adapter.is_stream_output_supported(): + await query.adapter.reply_message_chunk( + message_source=query.message_event, + bot_message=query.resp_messages[-1], + message=result.user_notice, + quote_origin=query.pipeline_config['output']['misc']['quote-origin'], + is_final=[msg.is_final for msg in query.resp_messages][0] + ) + else: + await query.adapter.reply_message( + message_source=query.message_event, + message=result.user_notice, + quote_origin=query.pipeline_config['output']['misc']['quote-origin'], + ) if result.debug_notice: self.ap.logger.debug(result.debug_notice) if result.console_notice: diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 2aa08e17..e913bbc2 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid import typing import traceback @@ -22,11 +23,11 @@ class ChatMessageHandler(handler.MessageHandler): self, query: core_entities.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: - """Process""" - # Call API - # generator + """处理""" + # 调API + # 生成器 - # Trigger plugin event + # 触发插件事件 event_class = ( events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON @@ -46,7 +47,6 @@ class ChatMessageHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: mc = platform_message.MessageChain(event_ctx.event.reply) - query.resp_messages.append(mc) yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) @@ -54,10 +54,14 @@ class ChatMessageHandler(handler.MessageHandler): yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) else: if event_ctx.event.alter is not None: - # if isinstance(event_ctx.event, str): # Currently not considering multi-modal alter + # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter query.user_message.content = event_ctx.event.alter text_length = 0 + try: + is_stream = await query.adapter.is_stream_output_supported() + except AttributeError: + is_stream = False try: for r in runner_module.preregistered_runners: @@ -65,22 +69,42 @@ class ChatMessageHandler(handler.MessageHandler): runner = r(self.ap, query.pipeline_config) break else: - raise ValueError(f'Request runner not found: {query.pipeline_config["ai"]["runner"]["runner"]}') + raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') + if is_stream: + resp_message_id = uuid.uuid4() + await query.adapter.create_message_card(str(resp_message_id), query.message_event) + async for result in runner.run(query): + result.resp_message_id = str(resp_message_id) + if query.resp_messages: + query.resp_messages.pop() + if query.resp_message_chain: + query.resp_message_chain.pop() - async for result in runner.run(query): - query.resp_messages.append(result) + query.resp_messages.append(result) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') - self.ap.logger.info(f'Response({query.query_id}): {self.cut_str(result.readable_str())}') + if result.content is not None: + text_length += len(result.content) - if result.content is not None: - text_length += len(result.content) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + else: + async for result in runner.run(query): + query.resp_messages.append(result) + + self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + + if result.content is not None: + text_length += len(result.content) + + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) query.session.using_conversation.messages.append(query.user_message) + query.session.using_conversation.messages.extend(query.resp_messages) except Exception as e: - self.ap.logger.error(f'Request failed({query.query_id}): {type(e).__name__} {str(e)}') + self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') + traceback.print_exc() hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] @@ -93,4 +117,4 @@ class ChatMessageHandler(handler.MessageHandler): ) finally: # TODO statistics - pass + pass \ No newline at end of file diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 39d3abb1..ece4e392 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -7,6 +7,10 @@ import asyncio from ...platform.types import events as platform_events from ...platform.types import message as platform_message +from ...provider import entities as llm_entities + + + from .. import stage, entities from ...core import entities as core_entities @@ -36,10 +40,22 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin = query.pipeline_config['output']['misc']['quote-origin'] - await query.adapter.reply_message( - message_source=query.message_event, - message=query.resp_message_chain[-1], - quote_origin=quote_origin, - ) + has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) + # TODO 命令与流式的兼容性问题 + if await query.adapter.is_stream_output_supported() and has_chunks: + is_final = [msg.is_final for msg in query.resp_messages][0] + await query.adapter.reply_message_chunk( + message_source=query.message_event, + bot_message=query.resp_messages[-1], + message=query.resp_message_chain[-1], + quote_origin=quote_origin, + is_final=is_final, + ) + else: + await query.adapter.reply_message( + message_source=query.message_event, + message=query.resp_message_chain[-1], + quote_origin=quote_origin, + ) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index f28ad3dc..e064ef80 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -61,14 +61,40 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): """ raise NotImplementedError + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message: dict, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ): + """回复消息(流式输出) + Args: + message_source (platform.types.MessageEvent): 消息源事件 + message_id (int): 消息ID + message (platform.types.MessageChain): 消息链 + quote_origin (bool, optional): 是否引用原消息. Defaults to False. + is_final (bool, optional): 流式是否结束. Defaults to False. + """ + raise NotImplementedError + + async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool: + """创建卡片消息 + Args: + message_id (str): 消息ID + event (platform_events.MessageEvent): 消息源事件 + """ + return False + async def is_muted(self, group_id: int) -> bool: """获取账号是否在指定群被禁言""" raise NotImplementedError def register_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注册事件监听器 @@ -80,8 +106,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def unregister_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注销事件监听器 @@ -95,6 +121,10 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): """异步运行""" raise NotImplementedError + async def is_stream_output_supported(self) -> bool: + """是否支持流式输出""" + return False + async def kill(self) -> bool: """关闭适配器 @@ -136,7 +166,7 @@ class EventConverter: """事件转换器基类""" @staticmethod - def yiri2target(event: typing.Type[platform_message.Event]): + def yiri2target(event: typing.Type[platform_events.Event]): """将源平台事件转换为目标平台事件 Args: @@ -148,7 +178,7 @@ class EventConverter: raise NotImplementedError @staticmethod - def target2yiri(event: typing.Any) -> platform_message.Event: + def target2yiri(event: typing.Any) -> platform_events.Event: """将目标平台事件的调用参数转换为源平台的事件参数对象 Args: diff --git a/pkg/platform/botmgr.py b/pkg/platform/botmgr.py index 5855525f..1da5eec8 100644 --- a/pkg/platform/botmgr.py +++ b/pkg/platform/botmgr.py @@ -120,8 +120,10 @@ class RuntimeBot: if isinstance(e, asyncio.CancelledError): self.task_context.set_current_action('Exited.') return + + traceback_str = traceback.format_exc() self.task_context.set_current_action('Exited with error.') - await self.logger.error(f'平台适配器运行出错:\n{e}\n{traceback.format_exc()}') + await self.logger.error(f'平台适配器运行出错:\n{e}\n{traceback_str}') self.task_wrapper = self.ap.task_mgr.create_task( exception_wrapper(), diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index a40b0f9b..71c7b0d0 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -1,3 +1,4 @@ +from re import S import traceback import typing from libs.dingtalk_api.dingtalkevent import DingTalkEvent @@ -99,11 +100,15 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): message_converter: DingTalkMessageConverter = DingTalkMessageConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter() config: dict + card_instance_id_dict: dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片 + seq: int # 消息顺序,直接以seq作为标识 def __init__(self, config: dict, ap: app.Application, logger: EventLogger): self.config = config self.ap = ap self.logger = logger + self.card_instance_id_dict = {} + # self.seq = 1 required_keys = [ 'client_id', 'client_secret', @@ -139,6 +144,34 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): content, at = await DingTalkMessageConverter.yiri2target(message) await self.bot.send_message(content, incoming_message, at) + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ): + # event = await DingTalkEventConverter.yiri2target( + # message_source, + # ) + # incoming_message = event.incoming_message + + # msg_id = incoming_message.message_id + message_id = bot_message.resp_message_id + msg_seq = bot_message.msg_sequence + + if (msg_seq - 1) % 8 == 0 or is_final: + + content, at = await DingTalkMessageConverter.yiri2target(message) + + card_instance, card_instance_id = self.card_instance_id_dict[message_id] + # print(card_instance_id) + await self.bot.send_card_message(card_instance, card_instance_id, content, is_final) + if is_final and bot_message.tool_calls is None: + # self.seq = 1 # 消息回复结束之后重置seq + self.card_instance_id_dict.pop(message_id) # 消息回复结束之后删除卡片实例id + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): content = await DingTalkMessageConverter.yiri2target(message) if target_type == 'person': @@ -146,6 +179,20 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): if target_type == 'group': await self.bot.send_proactive_message_to_group(target_id, content) + async def is_stream_output_supported(self) -> bool: + is_stream = False + if self.config.get('enable-stream-reply', None): + is_stream = True + return is_stream + + async def create_message_card(self, message_id, event): + card_template_id = self.config['card_template_id'] + incoming_message = event.source_platform_object.incoming_message + # message_id = incoming_message.message_id + card_instance, card_instance_id = await self.bot.create_and_card(card_template_id, incoming_message) + self.card_instance_id_dict[message_id] = (card_instance, card_instance_id) + return True + def register_listener( self, event_type: typing.Type[platform_events.Event], diff --git a/pkg/platform/sources/dingtalk.yaml b/pkg/platform/sources/dingtalk.yaml index fac2d6ff..70855c2b 100644 --- a/pkg/platform/sources/dingtalk.yaml +++ b/pkg/platform/sources/dingtalk.yaml @@ -46,6 +46,23 @@ spec: type: boolean required: false default: true + - name: enable-stream-reply + label: + en_US: Enable Stream Reply Mode + zh_Hans: 启用钉钉卡片流式回复模式 + description: + en_US: If enabled, the bot will use the stream of lark reply mode + zh_Hans: 如果启用,将使用钉钉卡片流式方式来回复内容 + type: boolean + required: true + default: false + - name: card_template_id + label: + en_US: card template id + zh_Hans: 卡片模板ID + type: string + required: true + default: "填写你的卡片template_id" execution: python: path: ./dingtalk.py diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index b55ee0bb..9e26f239 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,7 +8,6 @@ import base64 import uuid import os import datetime -import io import asyncio from enum import Enum diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index f8faf522..975730b5 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -17,6 +17,7 @@ import aiohttp import lark_oapi.ws.exception import quart from lark_oapi.api.im.v1 import * +from lark_oapi.api.cardkit.v1 import * from .. import adapter from ...core import app @@ -320,6 +321,10 @@ class LarkEventConverter(adapter.EventConverter): ) +CARD_ID_CACHE_SIZE = 500 +CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟 + + class LarkAdapter(adapter.MessagePlatformAdapter): bot: lark_oapi.ws.Client api_client: lark_oapi.Client @@ -339,12 +344,20 @@ class LarkAdapter(adapter.MessagePlatformAdapter): quart_app: quart.Quart ap: app.Application + + card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片 + + seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识 + def __init__(self, config: dict, ap: app.Application, logger: EventLogger): self.config = config self.ap = ap self.logger = logger self.quart_app = quart.Quart(__name__) self.listeners = {} + self.card_id_dict = {} + self.seq = 1 + @self.quart_app.route('/lark/callback', methods=['POST']) async def lark_callback(): @@ -409,6 +422,216 @@ class LarkAdapter(adapter.MessagePlatformAdapter): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass + async def is_stream_output_supported(self) -> bool: + is_stream = False + if self.config.get('enable-stream-reply', None): + is_stream = True + return is_stream + + async def create_card_id(self, message_id): + try: + self.ap.logger.debug('飞书支持stream输出,创建卡片......') + + card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True, + "streaming_config": {"print_step": {"default": 1}, + "print_frequency_ms": {"default": 70}, + "print_strategy": "fast"}}, + "body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div", + "text": { + "tag": "plain_text", + "content": "LangBot", + "text_size": "normal", + "text_align": "left", + "text_color": "default"}, + "icon": { + "tag": "custom_icon", + "img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}}, + { + "tag": "markdown", + "content": "", + "text_align": "left", + "text_size": "normal", + "margin": "0px 0px 0px 0px", + "element_id": "streaming_txt"}, + { + "tag": "markdown", + "content": "", + "text_align": "left", + "text_size": "normal", + "margin": "0px 0px 0px 0px"}, + { + "tag": "column_set", + "horizontal_spacing": "8px", + "horizontal_align": "left", + "columns": [ + { + "tag": "column", + "width": "weighted", + "elements": [ + { + "tag": "markdown", + "content": "", + "text_align": "left", + "text_size": "normal", + "margin": "0px 0px 0px 0px"}, + { + "tag": "markdown", + "content": "", + "text_align": "left", + "text_size": "normal", + "margin": "0px 0px 0px 0px"}, + { + "tag": "markdown", + "content": "", + "text_align": "left", + "text_size": "normal", + "margin": "0px 0px 0px 0px"}], + "padding": "0px 0px 0px 0px", + "direction": "vertical", + "horizontal_spacing": "8px", + "vertical_spacing": "2px", + "horizontal_align": "left", + "vertical_align": "top", + "margin": "0px 0px 0px 0px", + "weight": 1}], + "margin": "0px 0px 0px 0px"}, + {"tag": "hr", + "margin": "0px 0px 0px 0px"}, + { + "tag": "column_set", + "horizontal_spacing": "12px", + "horizontal_align": "right", + "columns": [ + { + "tag": "column", + "width": "weighted", + "elements": [ + { + "tag": "markdown", + "content": "以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看", + "text_align": "left", + "text_size": "notation", + "margin": "4px 0px 0px 0px", + "icon": { + "tag": "standard_icon", + "token": "robot_outlined", + "color": "grey"}}], + "padding": "0px 0px 0px 0px", + "direction": "vertical", + "horizontal_spacing": "8px", + "vertical_spacing": "8px", + "horizontal_align": "left", + "vertical_align": "top", + "margin": "0px 0px 0px 0px", + "weight": 1}, + { + "tag": "column", + "width": "20px", + "elements": [ + { + "tag": "button", + "text": { + "tag": "plain_text", + "content": ""}, + "type": "text", + "width": "fill", + "size": "medium", + "icon": { + "tag": "standard_icon", + "token": "thumbsup_outlined"}, + "hover_tips": { + "tag": "plain_text", + "content": "有帮助"}, + "margin": "0px 0px 0px 0px"}], + "padding": "0px 0px 0px 0px", + "direction": "vertical", + "horizontal_spacing": "8px", + "vertical_spacing": "8px", + "horizontal_align": "left", + "vertical_align": "top", + "margin": "0px 0px 0px 0px"}, + { + "tag": "column", + "width": "30px", + "elements": [ + { + "tag": "button", + "text": { + "tag": "plain_text", + "content": ""}, + "type": "text", + "width": "default", + "size": "medium", + "icon": { + "tag": "standard_icon", + "token": "thumbdown_outlined"}, + "hover_tips": { + "tag": "plain_text", + "content": "无帮助"}, + "margin": "0px 0px 0px 0px"}], + "padding": "0px 0px 0px 0px", + "vertical_spacing": "8px", + "horizontal_align": "left", + "vertical_align": "top", + "margin": "0px 0px 0px 0px"}], + "margin": "0px 0px 4px 0px"}]}} + # delay / fast 创建卡片模板,delay 延迟打印,fast 实时打印,可以自定义更好看的消息模板 + + request: CreateCardRequest = ( + CreateCardRequest.builder() + .request_body(CreateCardRequestBody.builder().type('card_json').data(json.dumps(card_data)).build()) + .build() + ) + + # 发起请求 + response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request) + + # 处理失败返回 + if not response.success(): + raise Exception( + f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + ) + + self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') + self.card_id_dict[message_id] = response.data.card_id + + card_id = response.data.card_id + return card_id + + except Exception as e: + self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}') + + async def create_message_card(self, message_id, event) -> str: + """ + 创建卡片消息。 + 使用卡片消息是因为普通消息更新次数有限制,而大模型流式返回结果可能很多而超过限制,而飞书卡片没有这个限制(api免费次数有限) + """ + # message_id = event.message_chain.message_id + + card_id = await self.create_card_id(message_id) + content = { + 'type': 'card', + 'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}}, + } # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档 + request: ReplyMessageRequest = ( + ReplyMessageRequest.builder() + .message_id(event.message_chain.message_id) + .request_body( + ReplyMessageRequestBody.builder().content(json.dumps(content)).msg_type('interactive').build() + ) + .build() + ) + + # 发起请求 + response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request) + + # 处理失败返回 + if not response.success(): + raise Exception( + f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + ) + return True + async def reply_message( self, message_source: platform_events.MessageEvent, @@ -447,6 +670,64 @@ class LarkAdapter(adapter.MessagePlatformAdapter): f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ): + """ + 回复消息变成更新卡片消息 + """ + # self.seq += 1 + message_id = bot_message.resp_message_id + msg_seq = bot_message.msg_sequence + if msg_seq % 8 == 0 or is_final: + + lark_message = await self.message_converter.yiri2target(message, self.api_client) + + + text_message = '' + for ele in lark_message[0]: + if ele['tag'] == 'text': + text_message += ele['text'] + elif ele['tag'] == 'md': + text_message += ele['text'] + + # content = { + # 'type': 'card_json', + # 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, + # } + + request: ContentCardElementRequest = ( + ContentCardElementRequest.builder() + .card_id(self.card_id_dict[message_id]) + .element_id('streaming_txt') + .request_body( + ContentCardElementRequestBody.builder() + # .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204") + .content(text_message) + .sequence(msg_seq) + .build() + ) + .build() + ) + + if is_final and bot_message.tool_calls is None: + # self.seq = 1 # 消息回复结束之后重置seq + self.card_id_dict.pop(message_id) # 清理已经使用过的卡片 + # 发起请求 + response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request) + + # 处理失败返回 + if not response.success(): + raise Exception( + f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + ) + return + async def is_muted(self, group_id: int) -> bool: return False @@ -492,4 +773,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter): ) async def kill(self) -> bool: - return False + # 需要断开连接,不然旧的连接会继续运行,导致飞书消息来时会随机选择一个连接 + # 断开时lark.ws.Client的_receive_message_loop会打印error日志: receive message loop exit。然后进行重连, + # 所以要设置_auto_reconnect=False,让其不重连。 + self.bot._auto_reconnect = False + await self.bot._disconnect() + return False \ No newline at end of file diff --git a/pkg/platform/sources/lark.yaml b/pkg/platform/sources/lark.yaml index f51bab76..94414b2e 100644 --- a/pkg/platform/sources/lark.yaml +++ b/pkg/platform/sources/lark.yaml @@ -65,6 +65,16 @@ spec: type: string required: true default: "" + - name: enable-stream-reply + label: + en_US: Enable Stream Reply Mode + zh_Hans: 启用飞书流式回复模式 + description: + en_US: If enabled, the bot will use the stream of lark reply mode + zh_Hans: 如果启用,将使用飞书流式方式来回复内容 + type: boolean + required: true + default: false execution: python: path: ./lark.py diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index c2fcc22e..8aee12d7 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -1,5 +1,6 @@ from __future__ import annotations + import telegram import telegram.ext from telegram import Update @@ -143,6 +144,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): config: dict ap: app.Application + msg_stream_id: dict # 流式消息id字典,key为流式消息id,value为首次消息源id,用于在流式消息时判断编辑那条消息 + + seq: int # 消息中识别消息顺序,直接以seq作为标识 + listeners: typing.Dict[ typing.Type[platform_events.Event], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], @@ -152,6 +157,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): self.config = config self.ap = ap self.logger = logger + self.msg_stream_id = {} + # self.seq = 1 async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): if update.message.from_user.is_bot: @@ -160,6 +167,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): try: lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) + await self.is_stream_output_supported() except Exception: await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') @@ -200,6 +208,70 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): await self.bot.send_message(**args) + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ): + msg_seq = bot_message.msg_sequence + if (msg_seq - 1) % 8 == 0 or is_final: + assert isinstance(message_source.source_platform_object, Update) + components = await TelegramMessageConverter.yiri2target(message, self.bot) + args = {} + message_id = message_source.source_platform_object.message.id + if quote_origin: + args['reply_to_message_id'] = message_source.source_platform_object.message.id + + component = components[0] + if message_id not in self.msg_stream_id: # 当消息回复第一次时,发送新消息 + # time.sleep(0.6) + if component['type'] == 'text': + if self.config['markdown_card'] is True: + content = telegramify_markdown.markdownify( + content=component['text'], + ) + else: + content = component['text'] + args = { + 'chat_id': message_source.source_platform_object.effective_chat.id, + 'text': content, + } + if self.config['markdown_card'] is True: + args['parse_mode'] = 'MarkdownV2' + + send_msg = await self.bot.send_message(**args) + send_msg_id = send_msg.message_id + self.msg_stream_id[message_id] = send_msg_id + else: # 存在消息的时候直接编辑消息1 + if component['type'] == 'text': + if self.config['markdown_card'] is True: + content = telegramify_markdown.markdownify( + content=component['text'], + ) + else: + content = component['text'] + args = { + 'message_id': self.msg_stream_id[message_id], + 'chat_id': message_source.source_platform_object.effective_chat.id, + 'text': content, + } + if self.config['markdown_card'] is True: + args['parse_mode'] = 'MarkdownV2' + + await self.bot.edit_message_text(**args) + if is_final and bot_message.tool_calls is None: + # self.seq = 1 # 消息回复结束之后重置seq + self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id + + async def is_stream_output_supported(self) -> bool: + is_stream = False + if self.config.get('enable-stream-reply', None): + is_stream = True + return is_stream + async def is_muted(self, group_id: int) -> bool: return False @@ -222,8 +294,12 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = (await self.bot.get_me()).username await self.application.updater.start_polling(allowed_updates=Update.ALL_TYPES) await self.application.start() + await self.logger.info('Telegram adapter running') async def kill(self) -> bool: if self.application.running: await self.application.stop() + if self.application.updater: + await self.application.updater.stop() + await self.logger.info('Telegram adapter stopped') return True diff --git a/pkg/platform/sources/telegram.yaml b/pkg/platform/sources/telegram.yaml index 43b9284b..d29c359e 100644 --- a/pkg/platform/sources/telegram.yaml +++ b/pkg/platform/sources/telegram.yaml @@ -25,6 +25,16 @@ spec: type: boolean required: false default: true + - name: enable-stream-reply + label: + en_US: Enable Stream Reply Mode + zh_Hans: 启用电报流式回复模式 + description: + en_US: If enabled, the bot will use the stream of telegram reply mode + zh_Hans: 如果启用,将使用电报流式方式来回复内容 + type: boolean + required: true + default: false execution: python: path: ./telegram.py diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index 51b0479f..c43c4628 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -19,17 +19,20 @@ class WebChatMessage(BaseModel): content: str message_chain: list[dict] timestamp: str + is_final: bool = False class WebChatSession: id: str message_lists: dict[str, list[WebChatMessage]] = {} resp_waiters: dict[int, asyncio.Future[WebChatMessage]] + resp_queues: dict[int, asyncio.Queue[WebChatMessage]] def __init__(self, id: str): self.id = id self.message_lists = {} self.resp_waiters = {} + self.resp_queues = {} def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]: if pipeline_uuid not in self.message_lists: @@ -49,6 +52,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None], ] = {} + is_stream: bool + def __init__(self, config: dict, ap: app.Application, logger: EventLogger): self.ap = ap self.logger = logger @@ -59,6 +64,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): self.bot_account_id = 'webchatbot' + self.is_stream = False + async def send_message( self, target_type: str, @@ -102,12 +109,53 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): # notify waiter if isinstance(message_source, platform_events.FriendMessage): - self.webchat_person_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data) + await self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data) elif isinstance(message_source, platform_events.GroupMessage): - self.webchat_group_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data) + await self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data) return message_data.model_dump() + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ) -> dict: + """回复消息""" + message_data = WebChatMessage( + id=-1, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + ) + + # notify waiter + session = ( + self.webchat_group_session + if isinstance(message_source, platform_events.GroupMessage) + else self.webchat_person_session + ) + if message_source.message_chain.message_id not in session.resp_waiters: + # session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue() + queue = session.resp_queues[message_source.message_chain.message_id] + + # if isinstance(message_source, platform_events.FriendMessage): + # queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id] + # elif isinstance(message_source, platform_events.GroupMessage): + # queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id] + if is_final and bot_message.tool_calls is None: + message_data.is_final = True + # print(message_data) + await queue.put(message_data) + + return message_data.model_dump() + + async def is_stream_output_supported(self) -> bool: + return self.is_stream + def register_listener( self, event_type: typing.Type[platform_events.Event], @@ -140,8 +188,13 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): await self.logger.info('WebChat调试适配器正在停止') async def send_webchat_message( - self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict] + self, + pipeline_uuid: str, + session_type: str, + message_chain_obj: typing.List[dict], + is_stream: bool = False, ) -> dict: + self.is_stream = is_stream """发送调试消息到流水线""" if session_type == 'person': use_session = self.webchat_person_session @@ -152,6 +205,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): message_id = len(use_session.get_message_list(pipeline_uuid)) + 1 + use_session.resp_queues[message_id] = asyncio.Queue() + logger.debug(f'Initialized queue for message_id: {message_id}') + use_session.get_message_list(pipeline_uuid).append( WebChatMessage( id=message_id, @@ -185,21 +241,46 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid + # trigger pipeline if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) - # set waiter - waiter = asyncio.Future[WebChatMessage]() - use_session.resp_waiters[message_id] = waiter - waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) + if is_stream: + queue = use_session.resp_queues[message_id] + msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1 + while True: + resp_message = await queue.get() + resp_message.id = msg_id + if resp_message.is_final: + resp_message.id = msg_id + use_session.get_message_list(pipeline_uuid).append(resp_message) + yield resp_message.model_dump() + break + yield resp_message.model_dump() + use_session.resp_queues.pop(message_id) - resp_message = await waiter + else: # non-stream + # set waiter + # waiter = asyncio.Future[WebChatMessage]() + # use_session.resp_waiters[message_id] = waiter + # # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) + # + # resp_message = await waiter + # + # resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 + # + # use_session.get_message_list(pipeline_uuid).append(resp_message) + # + # yield resp_message.model_dump() + msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1 - resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 + queue = use_session.resp_queues[message_id] + resp_message = await queue.get() + use_session.get_message_list(pipeline_uuid).append(resp_message) + resp_message.id = msg_id + resp_message.is_final = True - use_session.get_message_list(pipeline_uuid).append(resp_message) - - return resp_message.model_dump() + yield resp_message.model_dump() def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]: """获取调试消息历史""" diff --git a/pkg/platform/sources/webchat.yaml b/pkg/platform/sources/webchat.yaml index 4e8cc38e..748dfc8c 100644 --- a/pkg/platform/sources/webchat.yaml +++ b/pkg/platform/sources/webchat.yaml @@ -9,7 +9,8 @@ metadata: en_US: "WebChat adapter for pipeline debugging" zh_Hans: "用于流水线调试的网页聊天适配器" icon: "" -spec: {} +spec: + config: [] execution: python: path: "webchat.py" diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index ee9826e6..819ae400 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -240,8 +240,8 @@ class WeChatPadMessageConverter(adapter.MessageConverter): # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') quote_data = '' # 引用原文 - quote_id = None # 引用消息的原发送者 - tousername = None # 接收方: 所属微信的wxid + # quote_id = None # 引用消息的原发送者 + # tousername = None # 接收方: 所属微信的wxid user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member @@ -249,13 +249,10 @@ class WeChatPadMessageConverter(adapter.MessageConverter): if appmsg_data: user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') - quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') + # quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) - if message: - tousername = message['to_user_name']['str'] - - _ = quote_id - _ = tousername + # if message: + # tousername = message['to_user_name']['str'] if quote_data: quote_data_message_list = platform_message.MessageChain() diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 94b812d9..4c4a65c1 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -125,6 +125,95 @@ class Message(pydantic.BaseModel): return platform_message.MessageChain(mc) +class MessageChunk(pydantic.BaseModel): + """消息""" + + resp_message_id: typing.Optional[str] = None + """消息id""" + + role: str # user, system, assistant, tool, command, plugin + """消息的角色""" + + name: typing.Optional[str] = None + """名称,仅函数调用返回时设置""" + + all_content: typing.Optional[str] = None + """所有内容""" + + content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None + """内容""" + + tool_calls: typing.Optional[list[ToolCall]] = None + """工具调用""" + + tool_call_id: typing.Optional[str] = None + + is_final: bool = False + """是否是结束""" + + msg_sequence: int = 0 + """消息迭代次数""" + + def readable_str(self) -> str: + if self.content is not None: + return str(self.role) + ': ' + str(self.get_content_platform_message_chain()) + elif self.tool_calls is not None: + return f'调用工具: {self.tool_calls[0].id}' + else: + return '未知消息' + + def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None: + """将内容转换为平台消息 MessageChain 对象 + + Args: + prefix_text (str): 首个文字组件的前缀文本 + """ + + if self.content is None: + return None + elif isinstance(self.content, str): + return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) + elif isinstance(self.content, list): + mc = [] + for ce in self.content: + if ce.type == 'text': + mc.append(platform_message.Plain(ce.text)) + elif ce.type == 'image_url': + if ce.image_url.url.startswith('http'): + mc.append(platform_message.Image(url=ce.image_url.url)) + else: # base64 + b64_str = ce.image_url.url + + if b64_str.startswith('data:'): + b64_str = b64_str.split(',')[1] + + mc.append(platform_message.Image(base64=b64_str)) + + # 找第一个文字组件 + if prefix_text: + for i, c in enumerate(mc): + if isinstance(c, platform_message.Plain): + mc[i] = platform_message.Plain(prefix_text + c.text) + break + else: + mc.insert(0, platform_message.Plain(prefix_text)) + + return platform_message.MessageChain(mc) + + +class ToolCallChunk(pydantic.BaseModel): + """工具调用""" + + id: str + """工具调用ID""" + + type: str + """工具调用类型""" + + function: FunctionCall + """函数调用""" + + class Prompt(pydantic.BaseModel): """供AI使用的Prompt""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 17697cdb..6af8ba70 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -84,6 +84,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: """调用API @@ -92,12 +93,36 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): messages (typing.List[llm_entities.Message]): 消息对象列表 funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + remove_think (bool, optional): 是否移思考中的消息. Defaults to False. Returns: llm_entities.Message: 返回消息对象 """ pass + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.MessageChunk: + """调用API + + Args: + model (RuntimeLLMModel): 使用的模型信息 + messages (typing.List[llm_entities.Message]): 消息对象列表 + funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + remove_think (bool, optional): 是否移除思考中的消息. Defaults to False. + + Returns: + typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + """ + pass + async def invoke_embedding( self, model: RuntimeEmbeddingModel, diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index b195ae51..cb0c7ce1 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -21,7 +21,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): client: anthropic.AsyncAnthropic default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.anthropic.com/v1', + 'base_url': 'https://api.anthropic.com', 'timeout': 120, } @@ -44,6 +44,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): self.client = anthropic.AsyncAnthropic( api_key='', http_client=httpx_client, + base_url=self.requester_cfg['base_url'], ) async def invoke_llm( @@ -53,6 +54,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: self.client.api_key = model.token_mgr.get_token() @@ -89,7 +91,8 @@ class AnthropicMessages(requester.ProviderAPIRequester): { 'type': 'tool_result', 'tool_use_id': tool_call_id, - 'content': m.content, + 'is_error': False, + 'content': [{'type': 'text', 'text': m.content}], } ], } @@ -133,6 +136,9 @@ class AnthropicMessages(requester.ProviderAPIRequester): args['messages'] = req_messages + if 'thinking' in args: + args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000} + if funcs: tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs) @@ -140,19 +146,17 @@ class AnthropicMessages(requester.ProviderAPIRequester): args['tools'] = tools try: - # print(json.dumps(args, indent=4, ensure_ascii=False)) resp = await self.client.messages.create(**args) args = { 'content': '', 'role': resp.role, } - assert type(resp) is anthropic.types.message.Message for block in resp.content: - if block.type == 'thinking': - args['content'] = '' + block.thinking + '\n' + args['content'] + if not remove_think and block.type == 'thinking': + args['content'] = '\n' + block.thinking + '\n\n' + args['content'] elif block.type == 'text': args['content'] += block.text elif block.type == 'tool_use': @@ -176,3 +180,191 @@ class AnthropicMessages(requester.ProviderAPIRequester): raise errors.RequesterError(f'模型无效: {e.message}') else: raise errors.RequesterError(f'请求地址无效: {e.message}') + + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.Message: + self.client.api_key = model.token_mgr.get_token() + + args = extra_args.copy() + args['model'] = model.model_entity.name + args['stream'] = True + + # 处理消息 + + # system + system_role_message = None + + for i, m in enumerate(messages): + if m.role == 'system': + system_role_message = m + + break + + if system_role_message: + messages.pop(i) + + if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str): + args['system'] = system_role_message.content + + req_messages = [] + + for m in messages: + if m.role == 'tool': + tool_call_id = m.tool_call_id + + req_messages.append( + { + 'role': 'user', + 'content': [ + { + 'type': 'tool_result', + 'tool_use_id': tool_call_id, + 'is_error': False, # 暂时直接写false + 'content': [ + {'type': 'text', 'text': m.content} + ], # 这里要是list包裹,应该是多个返回的情况?type类型好像也可以填其他的,暂时只写text + } + ], + } + ) + + continue + + msg_dict = m.dict(exclude_none=True) + + if isinstance(m.content, str) and m.content.strip() != '': + msg_dict['content'] = [{'type': 'text', 'text': m.content}] + elif isinstance(m.content, list): + for i, ce in enumerate(m.content): + if ce.type == 'image_base64': + image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) + + alter_image_ele = { + 'type': 'image', + 'source': { + 'type': 'base64', + 'media_type': f'image/{image_format}', + 'data': image_b64, + }, + } + msg_dict['content'][i] = alter_image_ele + if isinstance(msg_dict['content'], str) and msg_dict['content'] == '': + msg_dict['content'] = [] # 这里不知道为什么会莫名有个空导致content为字符 + if m.tool_calls: + for tool_call in m.tool_calls: + msg_dict['content'].append( + { + 'type': 'tool_use', + 'id': tool_call.id, + 'name': tool_call.function.name, + 'input': json.loads(tool_call.function.arguments), + } + ) + + del msg_dict['tool_calls'] + + req_messages.append(msg_dict) + if 'thinking' in args: + args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000} + + args['messages'] = req_messages + + if funcs: + tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs) + + if tools: + args['tools'] = tools + + try: + role = 'assistant' # 默认角色 + # chunk_idx = 0 + think_started = False + think_ended = False + finish_reason = False + content = '' + tool_name = '' + tool_id = '' + async for chunk in await self.client.messages.create(**args): + tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'} + if isinstance( + chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent + ): # 记录开始 + if chunk.content_block.type == 'tool_use': + if chunk.content_block.name is not None: + tool_name = chunk.content_block.name + if chunk.content_block.id is not None: + tool_id = chunk.content_block.id + + tool_call['function']['name'] = tool_name + tool_call['function']['arguments'] = '' + tool_call['id'] = tool_id + + if not remove_think: + if chunk.content_block.type == 'thinking' and not remove_think: + think_started = True + elif chunk.content_block.type == 'text' and chunk.index != 0 and not remove_think: + think_ended = True + continue + elif isinstance(chunk, anthropic.types.raw_content_block_delta_event.RawContentBlockDeltaEvent): + if chunk.delta.type == 'thinking_delta': + if think_started: + think_started = False + content = '\n' + chunk.delta.thinking + elif remove_think: + continue + else: + content = chunk.delta.thinking + elif chunk.delta.type == 'text_delta': + if think_ended: + think_ended = False + content = '\n\n' + chunk.delta.text + else: + content = chunk.delta.text + elif chunk.delta.type == 'input_json_delta': + tool_call['function']['arguments'] = chunk.delta.partial_json + tool_call['function']['name'] = tool_name + tool_call['id'] = tool_id + elif isinstance(chunk, anthropic.types.raw_content_block_stop_event.RawContentBlockStopEvent): + continue # 记录raw_content_block结束的 + + elif isinstance(chunk, anthropic.types.raw_message_delta_event.RawMessageDeltaEvent): + if chunk.delta.stop_reason == 'end_turn': + finish_reason = True + elif isinstance(chunk, anthropic.types.raw_message_stop_event.RawMessageStopEvent): + continue # 这个好像是完全结束 + else: + # print(chunk) + self.ap.logger.debug(f'anthropic chunk: {chunk}') + continue + + args = { + 'content': content, + 'role': role, + 'is_final': finish_reason, + 'tool_calls': None if tool_call['id'] is None else [tool_call], + } + # if chunk_idx == 0: + # chunk_idx += 1 + # continue + + # assert type(chunk) is anthropic.types.message.Chunk + + yield llm_entities.MessageChunk(**args) + + # return llm_entities.Message(**args) + except anthropic.AuthenticationError as e: + raise errors.RequesterError(f'api-key 无效: {e.message}') + except anthropic.BadRequestError as e: + raise errors.RequesterError(str(e.message)) + except anthropic.NotFoundError as e: + if 'model: ' in str(e): + raise errors.RequesterError(f'模型无效: {e.message}') + else: + raise errors.RequesterError(f'请求地址无效: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index 7dbcf3ed..e3f745fb 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -14,7 +14,7 @@ spec: zh_Hans: 基础 URL type: string required: true - default: "https://api.anthropic.com/v1" + default: "https://api.anthropic.com" - name: timeout label: en_US: Timeout diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index aaaf3751..7afda84f 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -38,9 +38,18 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): ) -> chat_completion.ChatCompletion: return await self.client.chat.completions.create(**args, extra_body=extra_body) + async def _req_stream( + self, + args: dict, + extra_body: dict = {}, + ): + async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): + yield chunk + async def _make_msg( self, chat_completion: chat_completion.ChatCompletion, + remove_think: bool = False, ) -> llm_entities.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() @@ -48,16 +57,192 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: chatcmpl_message['role'] = 'assistant' - reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None + # 处理思维链 + content = chatcmpl_message.get('content', '') + reasoning_content = chatcmpl_message.get('reasoning_content', None) - # deepseek的reasoner模型 - if reasoning_content is not None: - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + processed_content, _ = await self._process_thinking_content( + content=content, reasoning_content=reasoning_content, remove_think=remove_think + ) + + chatcmpl_message['content'] = processed_content + + # 移除 reasoning_content 字段,避免传递给 Message + if 'reasoning_content' in chatcmpl_message: + del chatcmpl_message['reasoning_content'] message = llm_entities.Message(**chatcmpl_message) - return message + async def _process_thinking_content( + self, + content: str, + reasoning_content: str = None, + remove_think: bool = False, + ) -> tuple[str, str]: + """处理思维链内容 + + Args: + content: 原始内容 + reasoning_content: reasoning_content 字段内容 + remove_think: 是否移除思维链 + + Returns: + (处理后的内容, 提取的思维链内容) + """ + thinking_content = '' + + # 1. 从 reasoning_content 提取思维链 + if reasoning_content: + thinking_content = reasoning_content + + # 2. 从 content 中提取 标签内容 + if content and '' in content and '' in content: + import re + + think_pattern = r'(.*?)' + think_matches = re.findall(think_pattern, content, re.DOTALL) + if think_matches: + # 如果已有 reasoning_content,则追加 + if thinking_content: + thinking_content += '\n' + '\n'.join(think_matches) + else: + thinking_content = '\n'.join(think_matches) + # 移除 content 中的 标签 + content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip() + + # 3. 根据 remove_think 参数决定是否保留思维链 + if remove_think: + return content, '' + else: + # 如果有思维链内容,将其以 格式添加到 content 开头 + if thinking_content: + content = f'\n{thinking_content}\n\n{content}'.strip() + return content, thinking_content + + async def _closure_stream( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.MessageChunk: + self.client.api_key = use_model.token_mgr.get_token() + + args = {} + args['model'] = use_model.model_entity.name + + if use_funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + if tools: + args['tools'] = tools + + # 设置此次请求中的messages + messages = req_messages.copy() + + # 检查vision + for msg in messages: + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] + + args['messages'] = messages + args['stream'] = True + + # 流式处理状态 + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + chunk_idx = 0 + thinking_started = False + thinking_ended = False + role = 'assistant' # 默认角色 + tool_id = "" + tool_name = '' + # accumulated_reasoning = '' # 仅用于判断何时结束思维链 + + async for chunk in self._req_stream(args, extra_body=extra_args): + # 解析 chunk 数据 + + if hasattr(chunk, 'choices') and chunk.choices: + choice = chunk.choices[0] + delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} + + finish_reason = getattr(choice, 'finish_reason', None) + else: + delta = {} + finish_reason = None + # 从第一个 chunk 获取 role,后续使用这个 role + if 'role' in delta and delta['role']: + role = delta['role'] + + # 获取增量内容 + delta_content = delta.get('content', '') + reasoning_content = delta.get('reasoning_content', '') + + # 处理 reasoning_content + if reasoning_content: + # accumulated_reasoning += reasoning_content + # 如果设置了 remove_think,跳过 reasoning_content + if remove_think: + chunk_idx += 1 + continue + + # 第一次出现 reasoning_content,添加 开始标签 + if not thinking_started: + thinking_started = True + delta_content = '\n' + reasoning_content + else: + # 继续输出 reasoning_content + delta_content = reasoning_content + elif thinking_started and not thinking_ended and delta_content: + # reasoning_content 结束,normal content 开始,添加 结束标签 + thinking_ended = True + delta_content = '\n\n' + delta_content + + # 处理 content 中已有的 标签(如果需要移除) + # if delta_content and remove_think and '' in delta_content: + # import re + # + # # 移除 标签及其内容 + # delta_content = re.sub(r'.*?', '', delta_content, flags=re.DOTALL) + + # 处理工具调用增量 + # delta_tool_calls = None + if delta.get('tool_calls'): + for tool_call in delta['tool_calls']: + if tool_call['id'] and tool_call['function']['name']: + tool_id = tool_call['id'] + tool_name = tool_call['function']['name'] + else: + tool_call['id'] = tool_id + tool_call['function']['name'] = tool_name + if tool_call['type'] is None: + tool_call['type'] = 'function' + + + + # 跳过空的第一个 chunk(只有 role 没有内容) + if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): + chunk_idx += 1 + continue + # 构建 MessageChunk - 只包含增量内容 + chunk_data = { + 'role': role, + 'content': delta_content if delta_content else None, + 'tool_calls': delta.get('tool_calls'), + 'is_final': bool(finish_reason), + } + + # 移除 None 值 + chunk_data = {k: v for k, v in chunk_data.items() if v is not None} + + yield llm_entities.MessageChunk(**chunk_data) + chunk_idx += 1 + async def _closure( self, query: core_entities.Query, @@ -65,6 +250,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -92,10 +278,10 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages # 发送请求 - resp = await self._req(args, extra_body=extra_args) + resp = await self._req(args, extra_body=extra_args) # 处理请求结果 - message = await self._make_msg(resp) + message = await self._make_msg(resp, remove_think) return message @@ -106,6 +292,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: @@ -119,13 +306,15 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages.append(msg_dict) try: - return await self._closure( + msg = await self._closure( query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args, + remove_think=remove_think, ) + return msg except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -169,6 +358,45 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): raise errors.RequesterError('请求超时') except openai.BadRequestError as e: raise errors.RequesterError(f'请求参数错误: {e.message}') + + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.MessageChunk: + req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 + for m in messages: + msg_dict = m.dict(exclude_none=True) + content = msg_dict.get('content') + if isinstance(content, list): + # 检查 content 列表中是否每个部分都是文本 + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + # 将所有文本部分合并为一个字符串 + msg_dict['content'] = '\n'.join(part['text'] for part in content) + req_messages.append(msg_dict) + + try: + async for item in self._closure_stream( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + remove_think=remove_think, + ): + yield item + + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + else: + raise errors.RequesterError(f'请求参数错误: {e.message}') except openai.AuthenticationError as e: raise errors.RequesterError(f'无效的 api-key: {e.message}') except openai.NotFoundError as e: diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index 6d664b01..4866caf4 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -24,6 +24,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -49,10 +50,11 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): # 发送请求 resp = await self._req(args, extra_body=extra_args) + # print(resp) + if resp is None: raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常') - # 处理请求结果 - message = await self._make_msg(resp) + message = await self._make_msg(resp, remove_think) return message diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 3795ef99..f8cf15ca 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -3,14 +3,16 @@ from __future__ import annotations import typing -from . import chatcmpl +from . import ppiochatcmpl from .. import requester from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities +import re +import openai.types.chat.chat_completion as chat_completion -class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): +class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions): """Gitee AI ChatCompletions API 请求器""" default_config: dict[str, typing.Any] = { @@ -18,34 +20,3 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): 'timeout': 120, } - async def _closure( - self, - query: core_entities.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, - extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message: - self.client.api_key = use_model.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # gitee 不支持多模态,把content都转换成纯文字 - for m in req_messages: - if 'content' in m and isinstance(m['content'], list): - m['content'] = ' '.join([c['text'] for c in m['content']]) - - args['messages'] = req_messages - - resp = await self._req(args, extra_body=extra_args) - - message = await self._make_msg(resp) - - return message diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 4708f671..82d8df70 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import typing import openai @@ -34,9 +35,11 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async def _req( self, + query: core_entities.Query, args: dict, extra_body: dict = {}, - ) -> chat_completion.ChatCompletion: + remove_think: bool = False, + ) -> list[dict[str, typing.Any]]: args['stream'] = True chunk = None @@ -47,78 +50,75 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): resp_gen: openai.AsyncStream = await self.client.chat.completions.create(**args, extra_body=extra_body) + chunk_idx = 0 + thinking_started = False + thinking_ended = False + tool_id = '' + tool_name = '' + message_delta = {} async for chunk in resp_gen: - # print(chunk) if not chunk or not chunk.id or not chunk.choices or not chunk.choices[0] or not chunk.choices[0].delta: continue - if chunk.choices[0].delta.content is not None: - pending_content += chunk.choices[0].delta.content + delta = chunk.choices[0].delta.model_dump() if hasattr(chunk.choices[0], 'delta') else {} + reasoning_content = delta.get('reasoning_content') + # 处理 reasoning_content + if reasoning_content: + # accumulated_reasoning += reasoning_content + # 如果设置了 remove_think,跳过 reasoning_content + if remove_think: + chunk_idx += 1 + continue - if chunk.choices[0].delta.tool_calls is not None: - for tool_call in chunk.choices[0].delta.tool_calls: - if tool_call.function.arguments is None: + # 第一次出现 reasoning_content,添加 开始标签 + if not thinking_started: + thinking_started = True + pending_content += '\n' + reasoning_content + else: + # 继续输出 reasoning_content + pending_content += reasoning_content + elif thinking_started and not thinking_ended and delta.get('content'): + # reasoning_content 结束,normal content 开始,添加 结束标签 + thinking_ended = True + pending_content += '\n\n' + delta.get('content') + + if delta.get('content') is not None: + pending_content += delta.get('content') + + if delta.get('tool_calls') is not None: + for tool_call in delta.get('tool_calls'): + if tool_call['id'] != '': + tool_id = tool_call['id'] + if tool_call['function']['name'] is not None: + tool_name = tool_call['function']['name'] + if tool_call['function']['arguments'] is None: continue + tool_call['id'] = tool_id + tool_call['name'] = tool_name for tc in tool_calls: - if tc.index == tool_call.index: - tc.function.arguments += tool_call.function.arguments + if tc['index'] == tool_call['index']: + tc['function']['arguments'] += tool_call['function']['arguments'] break else: tool_calls.append(tool_call) if chunk.choices[0].finish_reason is not None: break + message_delta['content'] = pending_content + message_delta['role'] = 'assistant' - real_tool_calls = [] - - for tc in tool_calls: - function = chat_completion_message_tool_call.Function( - name=tc.function.name, arguments=tc.function.arguments - ) - real_tool_calls.append( - chat_completion_message_tool_call.ChatCompletionMessageToolCall( - id=tc.id, function=function, type='function' - ) - ) - - return ( - chat_completion.ChatCompletion( - id=chunk.id, - object='chat.completion', - created=chunk.created, - choices=[ - chat_completion.Choice( - index=0, - message=chat_completion.ChatCompletionMessage( - role='assistant', - content=pending_content, - tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None, - ), - finish_reason=chunk.choices[0].finish_reason - if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None - else 'stop', - logprobs=chunk.choices[0].logprobs, - ) - ], - model=chunk.model, - service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None, - system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None, - usage=chunk.usage if hasattr(chunk, 'usage') else None, - ) - if chunk - else None - ) + message_delta['tool_calls'] = tool_calls if tool_calls else None + return [message_delta] async def _make_msg( self, - chat_completion: chat_completion.ChatCompletion, + chat_completion: list[dict[str, typing.Any]], ) -> llm_entities.Message: - chatcmpl_message = chat_completion.choices[0].message.dict() + chatcmpl_message = chat_completion[0] # 确保 role 字段存在且不为 None if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: chatcmpl_message['role'] = 'assistant' - message = llm_entities.Message(**chatcmpl_message) return message @@ -130,6 +130,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think:bool = False, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -157,13 +158,146 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages # 发送请求 - resp = await self._req(args, extra_body=extra_args) + resp = await self._req(query, args, extra_body=extra_args, remove_think=remove_think) # 处理请求结果 message = await self._make_msg(resp) return message + async def _req_stream( + self, + args: dict, + extra_body: dict = {}, + ) -> chat_completion.ChatCompletion: + async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): + yield chunk + + + async def _closure_stream( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self.client.api_key = use_model.token_mgr.get_token() + + args = {} + args['model'] = use_model.model_entity.name + + if use_funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args['tools'] = tools + + # 设置此次请求中的messages + messages = req_messages.copy() + + # 检查vision + for msg in messages: + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] + + args['messages'] = messages + args['stream'] = True + + + # 流式处理状态 + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + chunk_idx = 0 + thinking_started = False + thinking_ended = False + role = 'assistant' # 默认角色 + # accumulated_reasoning = '' # 仅用于判断何时结束思维链 + + async for chunk in self._req_stream(args, extra_body=extra_args): + # 解析 chunk 数据 + if hasattr(chunk, 'choices') and chunk.choices: + choice = chunk.choices[0] + delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} + finish_reason = getattr(choice, 'finish_reason', None) + else: + delta = {} + finish_reason = None + + # 从第一个 chunk 获取 role,后续使用这个 role + if 'role' in delta and delta['role']: + role = delta['role'] + + # 获取增量内容 + delta_content = delta.get('content', '') + reasoning_content = delta.get('reasoning_content', '') + + # 处理 reasoning_content + if reasoning_content: + # accumulated_reasoning += reasoning_content + # 如果设置了 remove_think,跳过 reasoning_content + if remove_think: + chunk_idx += 1 + continue + + # 第一次出现 reasoning_content,添加 开始标签 + if not thinking_started: + thinking_started = True + delta_content = '\n' + reasoning_content + else: + # 继续输出 reasoning_content + delta_content = reasoning_content + elif thinking_started and not thinking_ended and delta_content: + # reasoning_content 结束,normal content 开始,添加 结束标签 + thinking_ended = True + delta_content = '\n\n' + delta_content + + # 处理 content 中已有的 标签(如果需要移除) + # if delta_content and remove_think and '' in delta_content: + # import re + # + # # 移除 标签及其内容 + # delta_content = re.sub(r'.*?', '', delta_content, flags=re.DOTALL) + + # 处理工具调用增量 + if delta.get('tool_calls'): + for tool_call in delta['tool_calls']: + if tool_call['id'] != '': + tool_id = tool_call['id'] + if tool_call['function']['name'] is not None: + tool_name = tool_call['function']['name'] + + if tool_call['type'] is None: + tool_call['type'] = 'function' + tool_call['id'] = tool_id + tool_call['function']['name'] = tool_name + tool_call['function']['arguments'] = "" if tool_call['function']['arguments'] is None else tool_call['function']['arguments'] + + + # 跳过空的第一个 chunk(只有 role 没有内容) + if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): + chunk_idx += 1 + continue + + # 构建 MessageChunk - 只包含增量内容 + chunk_data = { + 'role': role, + 'content': delta_content if delta_content else None, + 'tool_calls': delta.get('tool_calls'), + 'is_final': bool(finish_reason), + } + + # 移除 None 值 + chunk_data = {k: v for k, v in chunk_data.items() if v is not None} + + yield llm_entities.MessageChunk(**chunk_data) + chunk_idx += 1 + # return + async def invoke_llm( self, query: core_entities.Query, @@ -171,6 +305,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: @@ -185,7 +320,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): try: return await self._closure( - query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args + query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args, remove_think=remove_think ) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') @@ -202,3 +337,50 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.MessageChunk: + req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 + for m in messages: + msg_dict = m.dict(exclude_none=True) + content = msg_dict.get('content') + if isinstance(content, list): + # 检查 content 列表中是否每个部分都是文本 + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + # 将所有文本部分合并为一个字符串 + msg_dict['content'] = '\n'.join(part['text'] for part in content) + req_messages.append(msg_dict) + + try: + async for item in self._closure_stream( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + remove_think=remove_think, + ): + yield item + + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + else: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index f3621a09..494b2b0f 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -25,6 +25,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -54,6 +55,6 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): resp = await self._req(args, extra_body=extra_args) # 处理请求结果 - message = await self._make_msg(resp) + message = await self._make_msg(resp, remove_think) return message diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 9e6f5a77..42203650 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -44,6 +44,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: args = extra_args.copy() args['model'] = use_model.model_entity.name @@ -110,6 +111,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, ) -> llm_entities.Message: req_messages: list = [] for m in messages: @@ -126,6 +128,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): use_model=model, use_funcs=funcs, extra_args=extra_args, + remove_think=remove_think, ) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index 7e78ddb8..4af1cde0 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -4,6 +4,12 @@ import openai import typing from . import chatcmpl +import openai.types.chat.chat_completion as chat_completion +from .. import requester +from ....core import entities as core_entities +from ... import entities as llm_entities +from ...tools import entities as tools_entities +import re class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -15,3 +21,193 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): 'base_url': 'https://api.ppinfra.com/v3/openai', 'timeout': 120, } + + is_think: bool = False + + async def _make_msg( + self, + chat_completion: chat_completion.ChatCompletion, + remove_think: bool, + ) -> llm_entities.Message: + chatcmpl_message = chat_completion.choices[0].message.model_dump() + # print(chatcmpl_message.keys(), chatcmpl_message.values()) + + # 确保 role 字段存在且不为 None + if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: + chatcmpl_message['role'] = 'assistant' + + reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None + + # deepseek的reasoner模型 + chatcmpl_message["content"] = await self._process_thinking_content( + chatcmpl_message['content'],reasoning_content,remove_think) + + # 移除 reasoning_content 字段,避免传递给 Message + if 'reasoning_content' in chatcmpl_message: + del chatcmpl_message['reasoning_content'] + + + message = llm_entities.Message(**chatcmpl_message) + + return message + + async def _process_thinking_content( + self, + content: str, + reasoning_content: str = None, + remove_think: bool = False, + ) -> tuple[str, str]: + """处理思维链内容 + + Args: + content: 原始内容 + reasoning_content: reasoning_content 字段内容 + remove_think: 是否移除思维链 + + Returns: + 处理后的内容 + """ + if remove_think: + content = re.sub( + r'.*?', '', content, flags=re.DOTALL + ) + else: + if reasoning_content is not None: + content = ( + '\n' + reasoning_content + '\n\n' + content + ) + return content + + async def _make_msg_chunk( + self, + delta: dict[str, typing.Any], + idx: int, + ) -> llm_entities.MessageChunk: + # 处理流式chunk和完整响应的差异 + # print(chat_completion.choices[0]) + + # 确保 role 字段存在且不为 None + if 'role' not in delta or delta['role'] is None: + delta['role'] = 'assistant' + + reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None + + delta['content'] = '' if delta['content'] is None else delta['content'] + # print(reasoning_content) + + # deepseek的reasoner模型 + + if reasoning_content is not None: + delta['content'] += reasoning_content + + message = llm_entities.MessageChunk(**delta) + + return message + + async def _closure_stream( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self.client.api_key = use_model.token_mgr.get_token() + + args = {} + args['model'] = use_model.model_entity.name + + if use_funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args['tools'] = tools + + # 设置此次请求中的messages + messages = req_messages.copy() + + # 检查vision + for msg in messages: + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] + + args['messages'] = messages + args['stream'] = True + + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + chunk_idx = 0 + thinking_started = False + thinking_ended = False + role = 'assistant' # 默认角色 + async for chunk in self._req_stream(args, extra_body=extra_args): + # 解析 chunk 数据 + if hasattr(chunk, 'choices') and chunk.choices: + choice = chunk.choices[0] + delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} + finish_reason = getattr(choice, 'finish_reason', None) + else: + delta = {} + finish_reason = None + + # 从第一个 chunk 获取 role,后续使用这个 role + if 'role' in delta and delta['role']: + role = delta['role'] + + # 获取增量内容 + delta_content = delta.get('content', '') + # reasoning_content = delta.get('reasoning_content', '') + + if remove_think: + if delta['content'] is not None: + if '' in delta['content'] and not thinking_started and not thinking_ended: + thinking_started = True + continue + elif delta['content'] == r'' and not thinking_ended: + thinking_ended = True + continue + elif thinking_ended and delta['content'] == '\n\n' and thinking_started: + thinking_started = False + continue + elif thinking_started and not thinking_ended: + continue + + + delta_tool_calls = None + if delta.get('tool_calls'): + for tool_call in delta['tool_calls']: + if tool_call['id'] and tool_call['function']['name']: + tool_id = tool_call['id'] + tool_name = tool_call['function']['name'] + + if tool_call['id'] is None: + tool_call['id'] = tool_id + if tool_call['function']['name'] is None: + tool_call['function']['name'] = tool_name + if tool_call['function']['arguments'] is None: + tool_call['function']['arguments'] = '' + if tool_call['type'] is None: + tool_call['type'] = 'function' + + # 跳过空的第一个 chunk(只有 role 没有内容) + if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): + chunk_idx += 1 + continue + + # 构建 MessageChunk - 只包含增量内容 + chunk_data = { + 'role': role, + 'content': delta_content if delta_content else None, + 'tool_calls': delta.get('tool_calls'), + 'is_final': bool(finish_reason), + } + + # 移除 None 值 + chunk_data = {k: v for k, v in chunk_data.items() if v is not None} + + yield llm_entities.MessageChunk(**chunk_data) + chunk_idx += 1 diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 02cb0b51..737bc312 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -99,8 +99,14 @@ class DashScopeAPIRunner(runner.RequestRunner): plain_text = '' # 用户输入的纯文本信息 image_ids = [] # 用户输入的图片ID列表 (暂不支持) - plain_text, image_ids = await self._preprocess_user_message(query) + think_start = False + think_end = False + plain_text, image_ids = await self._preprocess_user_message(query) + has_thoughts = True # 获取思考过程 + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + if remove_think: + has_thoughts = False # 发送对话请求 response = dashscope.Application.call( api_key=self.api_key, # 智能体应用的API Key @@ -109,43 +115,108 @@ class DashScopeAPIRunner(runner.RequestRunner): stream=True, # 流式输出 incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 + has_thoughts=has_thoughts, # rag_options={ # 主要用于文件交互,暂不支持 # "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个 # } ) + idx_chunk = 0 + try: + is_stream = await query.adapter.is_stream_output_supported() - for chunk in response: - if chunk.get('status_code') != 200: - raise DashscopeAPIError( - f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' - ) - if not chunk: - continue + except AttributeError: + is_stream = False + if is_stream: + for chunk in response: + if chunk.get('status_code') != 200: + raise DashscopeAPIError( + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' + ) + if not chunk: + continue + idx_chunk += 1 + # 获取流式传输的output + stream_output = chunk.get('output', {}) + stream_think = stream_output.get('thoughts', []) + if stream_think[0].get('thought'): + if not think_start: + think_start = True + pending_content += f"\n{stream_think[0].get('thought')}" + else: + # 继续输出 reasoning_content + pending_content += stream_think[0].get('thought') + elif stream_think[0].get('thought') == "" and not think_end: + think_end = True + pending_content += "\n\n" + if stream_output.get('text') is not None: + pending_content += stream_output.get('text') + # 是否是流式最后一个chunk + is_final = False if stream_output.get('finish_reason', False) == 'null' else True - # 获取流式传输的output - stream_output = chunk.get('output', {}) - if stream_output.get('text') is not None: - pending_content += stream_output.get('text') + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) - # 保存当前会话的session_id用于下次对话的语境 - query.session.using_conversation.uuid = stream_output.get('session_id') + # 从模型传出的参考资料信息中提取用于替换的字典 + if references_dict_list is not None: + for doc in references_dict_list: + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') - # 获取模型传出的参考资料列表 - references_dict_list = stream_output.get('doc_references', []) + # 将参考资料替换到文本中 + pending_content = self._replace_references(pending_content, references_dict) - # 从模型传出的参考资料信息中提取用于替换的字典 - if references_dict_list is not None: - for doc in references_dict_list: - if doc.get('index_id') is not None: - references_dict[doc.get('index_id')] = doc.get('doc_name') + if idx_chunk % 8 == 0 or is_final: + yield llm_entities.MessageChunk( + role='assistant', + content=pending_content, + is_final=is_final, + ) + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') + else: + for chunk in response: + if chunk.get('status_code') != 200: + raise DashscopeAPIError( + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' + ) + if not chunk: + continue + idx_chunk += 1 + # 获取流式传输的output + stream_output = chunk.get('output', {}) + stream_think = stream_output.get('thoughts', []) + if stream_think[0].get('thought'): + if not think_start: + think_start = True + pending_content += f"\n{stream_think[0].get('thought')}" + else: + # 继续输出 reasoning_content + pending_content += stream_think[0].get('thought') + elif stream_think[0].get('thought') == "" and not think_end: + think_end = True + pending_content += "\n\n" + if stream_output.get('text') is not None: + pending_content += stream_output.get('text') - # 将参考资料替换到文本中 - pending_content = self._replace_references(pending_content, references_dict) + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') - yield llm_entities.Message( - role='assistant', - content=pending_content, - ) + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) + + # 从模型传出的参考资料信息中提取用于替换的字典 + if references_dict_list is not None: + for doc in references_dict_list: + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') + + # 将参考资料替换到文本中 + pending_content = self._replace_references(pending_content, references_dict) + + yield llm_entities.Message( + role='assistant', + content=pending_content, + ) async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 工作流对话请求""" @@ -171,52 +242,108 @@ class DashScopeAPIRunner(runner.RequestRunner): incremental_output=True, # 增量输出,使用流式输出需要开启增量输出 session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话 biz_params=biz_params, # 工作流应用的自定义输入参数传递 + flow_stream_mode="message_format" # 消息模式,输出/结束节点的流式结果 # rag_options={ # 主要用于文件交互,暂不支持 # "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个 # } ) # 处理API返回的流式输出 - for chunk in response: - if chunk.get('status_code') != 200: - raise DashscopeAPIError( - f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' - ) - if not chunk: - continue + try: + is_stream = await query.adapter.is_stream_output_supported() - # 获取流式传输的output - stream_output = chunk.get('output', {}) - if stream_output.get('text') is not None: - pending_content += stream_output.get('text') + except AttributeError: + is_stream = False + idx_chunk = 0 + if is_stream: + for chunk in response: + if chunk.get('status_code') != 200: + raise DashscopeAPIError( + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' + ) + if not chunk: + continue + idx_chunk += 1 + # 获取流式传输的output + stream_output = chunk.get('output', {}) + if stream_output.get('workflow_message') is not None: + pending_content += stream_output.get('workflow_message').get('message').get('content') + # if stream_output.get('text') is not None: + # pending_content += stream_output.get('text') - # 保存当前会话的session_id用于下次对话的语境 - query.session.using_conversation.uuid = stream_output.get('session_id') + is_final = False if stream_output.get('finish_reason', False) == 'null' else True - # 获取模型传出的参考资料列表 - references_dict_list = stream_output.get('doc_references', []) + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) - # 从模型传出的参考资料信息中提取用于替换的字典 - if references_dict_list is not None: - for doc in references_dict_list: - if doc.get('index_id') is not None: - references_dict[doc.get('index_id')] = doc.get('doc_name') + # 从模型传出的参考资料信息中提取用于替换的字典 + if references_dict_list is not None: + for doc in references_dict_list: + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') - # 将参考资料替换到文本中 - pending_content = self._replace_references(pending_content, references_dict) + # 将参考资料替换到文本中 + pending_content = self._replace_references(pending_content, references_dict) + if idx_chunk % 8 == 0 or is_final: + yield llm_entities.MessageChunk( + role='assistant', + content=pending_content, + is_final=is_final, + ) - yield llm_entities.Message( - role='assistant', - content=pending_content, - ) + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') + + else: + for chunk in response: + if chunk.get('status_code') != 200: + raise DashscopeAPIError( + f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} ' + ) + if not chunk: + continue + + # 获取流式传输的output + stream_output = chunk.get('output', {}) + if stream_output.get('text') is not None: + pending_content += stream_output.get('text') + + is_final = False if stream_output.get('finish_reason', False) == 'null' else True + + # 保存当前会话的session_id用于下次对话的语境 + query.session.using_conversation.uuid = stream_output.get('session_id') + + # 获取模型传出的参考资料列表 + references_dict_list = stream_output.get('doc_references', []) + + # 从模型传出的参考资料信息中提取用于替换的字典 + if references_dict_list is not None: + for doc in references_dict_list: + if doc.get('index_id') is not None: + references_dict[doc.get('index_id')] = doc.get('doc_name') + + # 将参考资料替换到文本中 + pending_content = self._replace_references(pending_content, references_dict) + + yield llm_entities.Message( + role='assistant', + content=pending_content, + ) async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" + msg_seq = 0 if self.app_type == 'agent': async for msg in self._agent_messages(query): + if isinstance(msg, llm_entities.MessageChunk): + msg_seq += 1 + msg.msg_sequence = msg_seq yield msg elif self.app_type == 'workflow': async for msg in self._workflow_messages(query): + if isinstance(msg, llm_entities.MessageChunk): + msg_seq += 1 + msg.msg_sequence = msg_seq yield msg else: raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}') diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index b2542491..b527a3bf 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -62,6 +62,39 @@ class DifyServiceAPIRunner(runner.RequestRunner): content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) return f'{thinking_text.group(1)}\n{content_text}' + def _process_thinking_content( + self, + content: str, + ) -> tuple[str, str]: + """处理思维链内容 + + Args: + content: 原始内容 + Returns: + (处理后的内容, 提取的思维链内容) + """ + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + thinking_content = '' + # 从 content 中提取 标签内容 + if content and '' in content and '' in content: + import re + + think_pattern = r'(.*?)' + think_matches = re.findall(think_pattern, content, re.DOTALL) + if think_matches: + thinking_content = '\n'.join(think_matches) + # 移除 content 中的 标签 + content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip() + + # 3. 根据 remove_think 参数决定是否保留思维链 + if remove_think: + return content, '' + else: + # 如果有思维链内容,将其以 格式添加到 content 开头 + if thinking_content: + content = f'\n{thinking_content}\n\n{content}'.strip() + return content, thinking_content + async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,并将图片上传到 Dify 服务 @@ -132,17 +165,20 @@ class DifyServiceAPIRunner(runner.RequestRunner): if mode == 'workflow': if chunk['event'] == 'node_finished': if chunk['data']['node_type'] == 'answer': + content, _ = self._process_thinking_content(chunk['data']['outputs']['answer']) + yield llm_entities.Message( role='assistant', - content=self._try_convert_thinking(chunk['data']['outputs']['answer']), + content=content, ) elif mode == 'basic': if chunk['event'] == 'message': basic_mode_pending_chunk += chunk['answer'] elif chunk['event'] == 'message_end': + content, _ = self._process_thinking_content(basic_mode_pending_chunk) yield llm_entities.Message( role='assistant', - content=self._try_convert_thinking(basic_mode_pending_chunk), + content=content, ) basic_mode_pending_chunk = '' @@ -193,14 +229,15 @@ class DifyServiceAPIRunner(runner.RequestRunner): if chunk['event'] in ignored_events: continue - if chunk['event'] == 'agent_message': + if chunk['event'] == 'agent_message' or chunk['event'] == 'message': pending_agent_message += chunk['answer'] else: if pending_agent_message.strip() != '': pending_agent_message = pending_agent_message.replace('Action:', '') + content, _ = self._process_thinking_content(pending_agent_message) yield llm_entities.Message( role='assistant', - content=self._try_convert_thinking(pending_agent_message), + content=content, ) pending_agent_message = '' @@ -308,26 +345,353 @@ class DifyServiceAPIRunner(runner.RequestRunner): elif chunk['event'] == 'workflow_finished': if chunk['data']['error']: raise errors.DifyAPIError(chunk['data']['error']) + content, _ = self._process_thinking_content(chunk['data']['outputs']['summary']) msg = llm_entities.Message( role='assistant', - content=chunk['data']['outputs']['summary'], + content=content, ) yield msg + + async def _chat_messages_chunk(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + """调用聊天助手""" + cov_id = query.session.using_conversation.uuid or '' + query.variables['conversation_id'] = cov_id + + plain_text, image_ids = await self._preprocess_user_message(query) + + files = [ + { + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, + } + for image_id in image_ids + ] + + mode = 'basic' # 标记是基础编排还是工作流编排 + + basic_mode_pending_chunk = '' + + inputs = {} + + inputs.update(query.variables) + message_idx = 0 + + chunk = None # 初始化chunk变量,防止在没有响应时引用错误 + + is_final = False + think_start = False + think_end = False + + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + + async for chunk in self.dify_client.chat_messages( + inputs=inputs, + query=plain_text, + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', + conversation_id=cov_id, + files=files, + timeout=120, + ): + self.ap.logger.debug('dify-chat-chunk: ' + str(chunk)) + + # if chunk['event'] == 'workflow_started': + # mode = 'workflow' + # if mode == 'workflow': + # elif mode == 'basic': + # 因为都只是返回的 message也没有工具调用什么的,暂时不分类 + if chunk['event'] == 'message': + message_idx += 1 + if remove_think: + if '' in chunk['answer'] and not think_start: + think_start = True + continue + if '' in chunk['answer'] and not think_end: + import re + content = re.sub(r'^\n', '', chunk['answer']) + basic_mode_pending_chunk += content + think_end = True + elif think_end: + basic_mode_pending_chunk += chunk['answer'] + if think_start: + continue + + else: + basic_mode_pending_chunk += chunk['answer'] + + if chunk['event'] == 'message_end': + is_final = True + + if is_final or message_idx % 8 == 0: + # content, _ = self._process_thinking_content(basic_mode_pending_chunk) + yield llm_entities.MessageChunk( + role='assistant', + content=basic_mode_pending_chunk, + is_final=is_final, + ) + + + if chunk is None: + raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置') + + query.session.using_conversation.uuid = chunk['conversation_id'] + + + async def _agent_chat_messages_chunk( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + """调用聊天助手""" + cov_id = query.session.using_conversation.uuid or '' + query.variables['conversation_id'] = cov_id + + plain_text, image_ids = await self._preprocess_user_message(query) + + files = [ + { + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, + } + for image_id in image_ids + ] + + ignored_events = [] + + inputs = {} + + inputs.update(query.variables) + + pending_agent_message = '' + + chunk = None # 初始化chunk变量,防止在没有响应时引用错误 + message_idx = 0 + is_final = False + think_start = False + think_end = False + + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + + async for chunk in self.dify_client.chat_messages( + inputs=inputs, + query=plain_text, + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', + response_mode='streaming', + conversation_id=cov_id, + files=files, + timeout=120, + ): + self.ap.logger.debug('dify-agent-chunk: ' + str(chunk)) + + if chunk['event'] in ignored_events: + continue + + if chunk['event'] == 'agent_message': + message_idx += 1 + if remove_think: + if '' in chunk['answer'] and not think_start: + think_start = True + continue + if '' in chunk['answer'] and not think_end: + import re + content = re.sub(r'^\n', '', chunk['answer']) + pending_agent_message += content + think_end = True + elif think_end: + pending_agent_message += chunk['answer'] + if think_start: + continue + + else: + pending_agent_message += chunk['answer'] + elif chunk['event'] == 'message_end': + is_final = True + else: + + if chunk['event'] == 'agent_thought': + if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过 + continue + message_idx += 1 + if chunk['tool']: + msg = llm_entities.MessageChunk( + role='assistant', + tool_calls=[ + llm_entities.ToolCall( + id=chunk['id'], + type='function', + function=llm_entities.FunctionCall( + name=chunk['tool'], + arguments=json.dumps({}), + ), + ) + ], + ) + yield msg + if chunk['event'] == 'message_file': + message_idx += 1 + if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant': + base_url = self.dify_client.base_url + + if base_url.endswith('/v1'): + base_url = base_url[:-3] + + image_url = base_url + chunk['url'] + + yield llm_entities.MessageChunk( + role='assistant', + content=[llm_entities.ContentElement.from_image_url(image_url)], + is_final=is_final, + + ) + + if chunk['event'] == 'error': + raise errors.DifyAPIError('dify 服务错误: ' + chunk['message']) + if message_idx % 8 == 0 or is_final: + yield llm_entities.MessageChunk( + role='assistant', + content=pending_agent_message, + is_final=is_final, + ) + + if chunk is None: + raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置') + + query.session.using_conversation.uuid = chunk['conversation_id'] + + async def _workflow_messages_chunk(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]: + """调用工作流""" + + if not query.session.using_conversation.uuid: + query.session.using_conversation.uuid = str(uuid.uuid4()) + + query.variables['conversation_id'] = query.session.using_conversation.uuid + + plain_text, image_ids = await self._preprocess_user_message(query) + + files = [ + { + 'type': 'image', + 'transfer_method': 'local_file', + 'upload_file_id': image_id, + } + for image_id in image_ids + ] + + ignored_events = ['workflow_started'] + + inputs = { # these variables are legacy variables, we need to keep them for compatibility + 'langbot_user_message_text': plain_text, + 'langbot_session_id': query.variables['session_id'], + 'langbot_conversation_id': query.variables['conversation_id'], + 'langbot_msg_create_time': query.variables['msg_create_time'], + } + + inputs.update(query.variables) + messsage_idx = 0 + is_final = False + think_start = False + think_end = False + workflow_contents = '' + + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') + async for chunk in self.dify_client.workflow_run( + inputs=inputs, + user=f'{query.session.launcher_type.value}_{query.session.launcher_id}', + files=files, + timeout=120, + ): + self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk)) + if chunk['event'] in ignored_events: + continue + if chunk['event'] == 'workflow_finished': + is_final = True + if chunk['data']['error']: + raise errors.DifyAPIError(chunk['data']['error']) + + if chunk['event'] == 'text_chunk': + messsage_idx += 1 + if remove_think: + if '' in chunk['data']['text'] and not think_start: + think_start = True + continue + if '' in chunk['data']['text'] and not think_end: + import re + content = re.sub(r'^\n', '', chunk['data']['text']) + workflow_contents += content + think_end = True + elif think_end: + workflow_contents += chunk['data']['text'] + if think_start: + continue + + else: + workflow_contents += chunk['data']['text'] + + if chunk['event'] == 'node_started': + if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end': + continue + messsage_idx += 1 + msg = llm_entities.MessageChunk( + role='assistant', + content=None, + tool_calls=[ + llm_entities.ToolCall( + id=chunk['data']['node_id'], + type='function', + function=llm_entities.FunctionCall( + name=chunk['data']['title'], + arguments=json.dumps({}), + ), + ) + ], + ) + + yield msg + + + if messsage_idx % 8 == 0 or is_final: + yield llm_entities.MessageChunk( + role='assistant', + content=workflow_contents, + is_final=is_final, + ) + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" - if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': - async for msg in self._chat_messages(query): - yield msg - elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent': - async for msg in self._agent_chat_messages(query): - yield msg - elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow': - async for msg in self._workflow_messages(query): - yield msg + if await query.adapter.is_stream_output_supported(): + msg_idx = 0 + if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': + async for msg in self._chat_messages_chunk(query): + msg_idx += 1 + msg.msg_sequence = msg_idx + yield msg + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent': + async for msg in self._agent_chat_messages_chunk(query): + msg_idx += 1 + msg.msg_sequence = msg_idx + yield msg + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow': + async for msg in self._workflow_messages_chunk(query): + msg_idx += 1 + msg.msg_sequence = msg_idx + yield msg + else: + raise errors.DifyAPIError( + f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' + ) else: - raise errors.DifyAPIError( - f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' - ) + if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': + async for msg in self._chat_messages(query): + yield msg + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent': + async for msg in self._agent_chat_messages(query): + yield msg + elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow': + async for msg in self._workflow_messages(query): + yield msg + else: + raise errors.DifyAPIError( + f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' + ) \ No newline at end of file diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index f7dfcb52..2500b363 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -27,7 +27,16 @@ Respond in the same language as the user's input. class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + class ToolCallTracker: + """工具调用追踪器""" + + def __init__(self): + self.active_calls: dict[str, dict] = {} + self.completed_calls: list[llm_entities.ToolCall] = [] + + async def run( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: """运行请求""" pending_tool_calls = [] @@ -80,20 +89,92 @@ class LocalAgentRunner(runner.RequestRunner): req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message] - # 首次请求 - msg = await query.use_llm_model.requester.invoke_llm( - query, - query.use_llm_model, - req_messages, - query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, - ) + try: + is_stream = await query.adapter.is_stream_output_supported() + except AttributeError: + is_stream = False - yield msg + remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think') - pending_tool_calls = msg.tool_calls + if not is_stream: + # 非流式输出,直接请求 - req_messages.append(msg) + msg = await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ) + yield msg + final_msg = msg + else: + # 流式输出,需要处理工具调用 + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + msg_idx = 0 + accumulated_content = '' # 从开始累积的所有内容 + last_role = 'assistant' + msg_sequence = 1 + async for msg in query.use_llm_model.requester.invoke_llm_stream( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ): + msg_idx = msg_idx + 1 + + # 记录角色 + if msg.role: + last_role = msg.role + + # 累积内容 + if msg.content: + accumulated_content += msg.content + + # 处理工具调用 + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + # continue + # 每8个chunk或最后一个chunk时,输出所有累积的内容 + if msg_idx % 8 == 0 or msg.is_final: + msg_sequence += 1 + yield llm_entities.MessageChunk( + role=last_role, + content=accumulated_content, # 输出所有累积内容 + tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, + is_final=msg.is_final, + msg_sequence=msg_sequence, + ) + + # 创建最终消息用于后续处理 + final_msg = llm_entities.MessageChunk( + role=last_role, + content=accumulated_content, + tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, + msg_sequence=msg_sequence, + ) + + pending_tool_calls = final_msg.tool_calls + first_content = final_msg.content + if isinstance(final_msg, llm_entities.MessageChunk): + + first_end_sequence = final_msg.msg_sequence + + req_messages.append(final_msg) # 持续请求,只要还有待处理的工具调用就继续处理调用 while pending_tool_calls: @@ -104,12 +185,18 @@ class LocalAgentRunner(runner.RequestRunner): parameters = json.loads(func.arguments) func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters) - - msg = llm_entities.Message( - role='tool', - content=json.dumps(func_ret, ensure_ascii=False), - tool_call_id=tool_call.id, - ) + if is_stream: + msg = llm_entities.MessageChunk( + role='tool', + content=json.dumps(func_ret, ensure_ascii=False), + tool_call_id=tool_call.id, + ) + else: + msg = llm_entities.Message( + role='tool', + content=json.dumps(func_ret, ensure_ascii=False), + tool_call_id=tool_call.id, + ) yield msg @@ -122,17 +209,82 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(err_msg) - # 处理完所有调用,再次请求 - msg = await query.use_llm_model.requester.invoke_llm( - query, - query.use_llm_model, - req_messages, - query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, - ) + if is_stream: + tool_calls_map = {} + msg_idx = 0 + accumulated_content = '' # 从开始累积的所有内容 + last_role = 'assistant' + msg_sequence = first_end_sequence - yield msg + async for msg in query.use_llm_model.requester.invoke_llm_stream( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ): + msg_idx += 1 - pending_tool_calls = msg.tool_calls + # 记录角色 + if msg.role: + last_role = msg.role - req_messages.append(msg) + # 第一次请求工具调用时的内容 + if msg_idx == 1: + accumulated_content = first_content if first_content is not None else accumulated_content + + # 累积内容 + if msg.content: + accumulated_content += msg.content + + # 处理工具调用 + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + + # 每8个chunk或最后一个chunk时,输出所有累积的内容 + if msg_idx % 8 == 0 or msg.is_final: + msg_sequence += 1 + yield llm_entities.MessageChunk( + role=last_role, + content=accumulated_content, # 输出所有累积内容 + tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, + is_final=msg.is_final, + msg_sequence=msg_sequence, + ) + + final_msg = llm_entities.MessageChunk( + role=last_role, + content=accumulated_content, + tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, + msg_sequence=msg_sequence, + + ) + else: + # 处理完所有调用,再次请求 + msg = await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ) + + yield msg + final_msg = msg + + pending_tool_calls = final_msg.tool_calls + + req_messages.append(final_msg) diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index e13958d9..28d6e3e5 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,6 +1,6 @@ semantic_version = 'v4.1.2' -required_database_version = 4 +required_database_version = 5 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/templates/default-pipeline-config.json b/templates/default-pipeline-config.json index d06e4661..855e2ac6 100644 --- a/templates/default-pipeline-config.json +++ b/templates/default-pipeline-config.json @@ -87,7 +87,8 @@ "hide-exception": true, "at-sender": true, "quote-origin": true, - "track-function-calls": false + "track-function-calls": false, + "remove-think": true } } } \ No newline at end of file diff --git a/templates/metadata/pipeline/ai.yaml b/templates/metadata/pipeline/ai.yaml index ffbefe63..63c56a8a 100644 --- a/templates/metadata/pipeline/ai.yaml +++ b/templates/metadata/pipeline/ai.yaml @@ -138,6 +138,8 @@ stages: label: en_US: Remove zh_Hans: 移除 + + - name: dashscope-app-api label: en_US: Aliyun Dashscope App API diff --git a/templates/metadata/pipeline/output.yaml b/templates/metadata/pipeline/output.yaml index 9fe0cd25..66bb312c 100644 --- a/templates/metadata/pipeline/output.yaml +++ b/templates/metadata/pipeline/output.yaml @@ -105,3 +105,13 @@ stages: type: boolean required: true default: false + - name: remove-think + label: + en_US: Remove CoT + zh_Hans: 删除思维链 + description: + en_US: If enabled, LangBot will remove the LLM thought content in response + zh_Hans: 如果启用,将自动删除大模型回复中的模型思考内容 + type: boolean + required: true + default: true diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx index abc717c6..eee02b5a 100644 --- a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -298,6 +298,18 @@ export default function EmbeddingForm({ function testEmbeddingModelInForm() { setModelTesting(true); + const extraArgsObj: Record = {}; + form + .getValues('extra_args') + ?.forEach((arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }); httpClient .testEmbeddingModel('_', { uuid: '', @@ -309,6 +321,7 @@ export default function EmbeddingForm({ timeout: 120, }, api_keys: [form.getValues('api_key')], + extra_args: extraArgsObj, }) .then((res) => { console.log(res); diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index 54e0b9ce..434e68a4 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -312,6 +312,18 @@ export default function LLMForm({ function testLLMModelInForm() { setModelTesting(true); + const extraArgsObj: Record = {}; + form + .getValues('extra_args') + ?.forEach((arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }); httpClient .testLLMModel('_', { uuid: '', @@ -324,7 +336,7 @@ export default function LLMForm({ }, api_keys: [form.getValues('api_key')], abilities: form.getValues('abilities'), - extra_args: form.getValues('extra_args'), + extra_args: extraArgsObj, }) .then((res) => { console.log(res); diff --git a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx index 1d2c0840..71e5b748 100644 --- a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx +++ b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useRef, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { httpClient } from '@/app/infra/http/HttpClient'; import { DialogContent } from '@/components/ui/dialog'; @@ -10,6 +10,7 @@ import { cn } from '@/lib/utils'; import { Message } from '@/app/infra/entities/message'; import { toast } from 'sonner'; import AtBadge from './AtBadge'; +import { Switch } from '@/components/ui/switch'; interface MessageComponent { type: 'At' | 'Plain'; @@ -36,17 +37,44 @@ export default function DebugDialog({ const [showAtPopover, setShowAtPopover] = useState(false); const [hasAt, setHasAt] = useState(false); const [isHovering, setIsHovering] = useState(false); + const [isStreaming, setIsStreaming] = useState(true); const messagesEndRef = useRef(null); const inputRef = useRef(null); const popoverRef = useRef(null); - const scrollToBottom = () => { - messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); - }; + const scrollToBottom = useCallback(() => { + // 使用setTimeout确保在DOM更新后执行滚动 + setTimeout(() => { + const scrollArea = document.querySelector('.scroll-area') as HTMLElement; + if (scrollArea) { + scrollArea.scrollTo({ + top: scrollArea.scrollHeight, + behavior: 'smooth', + }); + } + // 同时确保messagesEndRef也滚动到视图 + messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); + }, 0); + }, []); + const loadMessages = useCallback( + async (pipelineId: string) => { + try { + const response = await httpClient.getWebChatHistoryMessages( + pipelineId, + sessionType, + ); + setMessages(response.messages); + } catch (error) { + console.error('Failed to load messages:', error); + } + }, + [sessionType], + ); + // 在useEffect中监听messages变化时滚动 useEffect(() => { scrollToBottom(); - }, [messages]); + }, [messages, scrollToBottom]); useEffect(() => { if (open) { @@ -59,7 +87,7 @@ export default function DebugDialog({ if (open) { loadMessages(selectedPipelineId); } - }, [sessionType, selectedPipelineId]); + }, [sessionType, selectedPipelineId, open, loadMessages]); useEffect(() => { const handleClickOutside = (event: MouseEvent) => { @@ -84,18 +112,6 @@ export default function DebugDialog({ } }, [showAtPopover]); - const loadMessages = async (pipelineId: string) => { - try { - const response = await httpClient.getWebChatHistoryMessages( - pipelineId, - sessionType, - ); - setMessages(response.messages); - } catch (error) { - console.error('Failed to load messages:', error); - } - }; - const handleInputChange = (e: React.ChangeEvent) => { const value = e.target.value; if (sessionType === 'group') { @@ -165,19 +181,87 @@ export default function DebugDialog({ timestamp: new Date().toISOString(), message_chain: messageChain, }; + // 根据isStreaming状态决定使用哪种传输方式 + if (isStreaming) { + // streaming + // 创建初始bot消息 + const placeholderRandomId = Math.floor(Math.random() * 1000000); + const botMessagePlaceholder: Message = { + id: placeholderRandomId, + role: 'assistant', + content: 'Generating...', + timestamp: new Date().toISOString(), + message_chain: [{ type: 'Plain', text: 'Generating...' }], + }; - setMessages((prevMessages) => [...prevMessages, userMessage]); - setInputValue(''); - setHasAt(false); + // 添加用户消息和初始bot消息到状态 - const response = await httpClient.sendWebChatMessage( - sessionType, - messageChain, - selectedPipelineId, - 120000, - ); + setMessages((prevMessages) => [ + ...prevMessages, + userMessage, + botMessagePlaceholder, + ]); + setInputValue(''); + setHasAt(false); + try { + await httpClient.sendStreamingWebChatMessage( + sessionType, + messageChain, + selectedPipelineId, + (data) => { + // 处理流式响应数据 + console.log('data', data); + if (data.message) { + // 更新完整内容 - setMessages((prevMessages) => [...prevMessages, response.message]); + setMessages((prevMessages) => { + const updatedMessages = [...prevMessages]; + const botMessageIndex = updatedMessages.findIndex( + (message) => message.id === placeholderRandomId, + ); + if (botMessageIndex !== -1) { + updatedMessages[botMessageIndex] = { + ...updatedMessages[botMessageIndex], + content: data.message.content, + message_chain: [ + { type: 'Plain', text: data.message.content }, + ], + }; + } + return updatedMessages; + }); + } + }, + () => {}, + (error) => { + // 处理错误 + console.error('Streaming error:', error); + if (sessionType === 'person') { + toast.error(t('pipelines.debugDialog.sendFailed')); + } + }, + ); + } catch (error) { + console.error('Failed to send streaming message:', error); + if (sessionType === 'person') { + toast.error(t('pipelines.debugDialog.sendFailed')); + } + } + } else { + // non-streaming + setMessages((prevMessages) => [...prevMessages, userMessage]); + setInputValue(''); + setHasAt(false); + + const response = await httpClient.sendWebChatMessage( + sessionType, + messageChain, + selectedPipelineId, + 180000, + ); + + setMessages((prevMessages) => [...prevMessages, response.message]); + } } catch ( // eslint-disable-next-line @typescript-eslint/no-explicit-any error: any @@ -306,6 +390,12 @@ export default function DebugDialog({
+
+ + {t('pipelines.debugDialog.streaming')} + + +
{hasAt && ( diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5e090e99..9f5967d0 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -372,6 +372,99 @@ class HttpClient { ); } + public async sendStreamingWebChatMessage( + sessionType: string, + messageChain: object[], + pipelineId: string, + onMessage: (data: ApiRespWebChatMessage) => void, + onComplete: () => void, + onError: (error: Error) => void, + ): Promise { + try { + // 构造完整的URL,处理相对路径的情况 + let url = `${this.baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; + if (this.baseURL === '/') { + // 获取用户访问的完整URL + const baseURL = window.location.origin; + url = `${baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; + } + + // 使用fetch发送流式请求,因为axios在浏览器环境中不直接支持流式响应 + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.getSessionSync()}`, + }, + body: JSON.stringify({ + session_type: sessionType, + message: messageChain, + is_stream: true, + }), + }); + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + if (!response.body) { + throw new Error('ReadableStream not supported'); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + // 读取流式响应 + try { + while (true) { + const { done, value } = await reader.read(); + + if (done) { + onComplete(); + break; + } + + // 解码数据 + buffer += decoder.decode(value, { stream: true }); + + // 处理完整的JSON对象 + const lines = buffer.split('\n\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (line.startsWith('data:')) { + try { + const data = JSON.parse(line.slice(5)); + + if (data.type === 'end') { + // 流传输结束 + reader.cancel(); + onComplete(); + return; + } + if (data.type === 'start') { + console.log(data.type); + } + + if (data.message) { + // 处理消息数据 + onMessage(data); + } + } catch (error) { + console.error('Error parsing streaming data:', error); + } + } + } + } + } finally { + reader.releaseLock(); + } + } catch (error) { + onError(error as Error); + } + } + public getWebChatHistoryMessages( pipelineId: string, sessionType: string, diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 70635bd4..5bc959b1 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -250,6 +250,7 @@ const enUS = { loadMessagesFailed: 'Failed to load messages', loadPipelinesFailed: 'Failed to load pipelines', atTips: 'Mention the bot', + streaming: 'Streaming', }, }, knowledge: { diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 6e286727..43b78f2a 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -252,6 +252,7 @@ const jaJP = { loadMessagesFailed: 'メッセージの読み込みに失敗しました', loadPipelinesFailed: 'パイプラインの読み込みに失敗しました', atTips: 'ボットをメンション', + streaming: 'ストリーミング', }, }, knowledge: { diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 256124dc..da3f50a2 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -244,6 +244,7 @@ const zhHans = { loadMessagesFailed: '加载消息失败', loadPipelinesFailed: '加载流水线失败', atTips: '提及机器人', + streaming: '流式传输', }, }, knowledge: {