From 48c9d66ab8151f7bf5ccdab6b4a6981e1b7b6600 Mon Sep 17 00:00:00 2001 From: fdc Date: Tue, 1 Jul 2025 18:03:05 +0800 Subject: [PATCH] =?UTF-8?q?chat=E4=B8=AD=E7=9A=84=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/process/handlers/chat.py | 23 +++++++++++------------ pkg/pipeline/respback/respback.py | 23 +++++++++-------------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index c90d283b..9b3e0cd5 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -70,15 +70,15 @@ class ChatMessageHandler(handler.MessageHandler): else: raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') if is_stream: - accumulated_messages = [] - async for result in runner.run(query): - accumulated_messages.append(result) - query.resp_messages.append(result) + async for results in runner.run(query): + async for result in results: - self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') + query.resp_messages.append(result) - if result.content is not None: - text_length += len(result.content) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') + + if result.content is not None: + text_length += len(result.content) # current_chain = platform_message.MessageChain([]) # for msg in accumulated_messages: @@ -86,12 +86,11 @@ class ChatMessageHandler(handler.MessageHandler): # current_chain.append(platform_message.Plain(msg.content)) # query.resp_message_chain = [current_chain] - - - - + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + # query.resp_messages.append(results) + # self.ap.logger.info(f'对话({query.query_id})响应') + # yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: async for result in runner.run(query): diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 7654896b..4ac4e1e3 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -7,6 +7,8 @@ import asyncio from ...platform.types import events as platform_events from ...platform.types import message as platform_message +from ...provider import entities as llm_entities + from .. import stage, entities from ...core import entities as core_entities @@ -38,17 +40,10 @@ class SendResponseBackStage(stage.PipelineStage): has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) if has_chunks and hasattr(query.adapter,'reply_message_chunk'): - - async def message_generator(): - for msg in query.resp_messages: - if isinstance(msg, llm_entities.MessageChunk): - yield msg.content - else: - yield msg.content await query.adapter.reply_message_chunk( message_source=query.message_event, - message_id=query.message_event.message_id, - message_generator=message_generator(), + message_id=query.query_id, + message_generator=query.resp_message_chain[-1], quote_origin=quote_origin, ) else: @@ -58,10 +53,10 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin=quote_origin, ) - await query.adapter.reply_message( - message_source=query.message_event, - message=query.resp_message_chain[-1], - quote_origin=quote_origin, - ) + # await query.adapter.reply_message( + # message_source=query.message_event, + # message=query.resp_message_chain[-1], + # quote_origin=quote_origin, + # ) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)