diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index c90d283b..9b3e0cd5 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -70,15 +70,15 @@ class ChatMessageHandler(handler.MessageHandler): else: raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') if is_stream: - accumulated_messages = [] - async for result in runner.run(query): - accumulated_messages.append(result) - query.resp_messages.append(result) + async for results in runner.run(query): + async for result in results: - self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') + query.resp_messages.append(result) - if result.content is not None: - text_length += len(result.content) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') + + if result.content is not None: + text_length += len(result.content) # current_chain = platform_message.MessageChain([]) # for msg in accumulated_messages: @@ -86,12 +86,11 @@ class ChatMessageHandler(handler.MessageHandler): # current_chain.append(platform_message.Plain(msg.content)) # query.resp_message_chain = [current_chain] - - - - + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + # query.resp_messages.append(results) + # self.ap.logger.info(f'对话({query.query_id})响应') + # yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: async for result in runner.run(query): diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 7654896b..4ac4e1e3 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -7,6 +7,8 @@ import asyncio from ...platform.types import events as platform_events from ...platform.types import message as platform_message +from ...provider import entities as llm_entities + from .. import stage, entities from ...core import entities as core_entities @@ -38,17 +40,10 @@ class SendResponseBackStage(stage.PipelineStage): has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) if has_chunks and hasattr(query.adapter,'reply_message_chunk'): - - async def message_generator(): - for msg in query.resp_messages: - if isinstance(msg, llm_entities.MessageChunk): - yield msg.content - else: - yield msg.content await query.adapter.reply_message_chunk( message_source=query.message_event, - message_id=query.message_event.message_id, - message_generator=message_generator(), + message_id=query.query_id, + message_generator=query.resp_message_chain[-1], quote_origin=quote_origin, ) else: @@ -58,10 +53,10 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin=quote_origin, ) - await query.adapter.reply_message( - message_source=query.message_event, - message=query.resp_message_chain[-1], - quote_origin=quote_origin, - ) + # await query.adapter.reply_message( + # message_source=query.message_event, + # message=query.resp_message_chain[-1], + # quote_origin=quote_origin, + # ) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)