From 051fffd41e16813bcf91aa8dfb784e9fac8145ab Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 7 Aug 2025 21:56:40 +0800 Subject: [PATCH] fix: stash --- .../controller/groups/pipelines/webchat.py | 3 +- pkg/platform/sources/webchat.py | 39 ++++++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index ae201934..7eea471a 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -48,8 +48,7 @@ class WebChatDebugRouterGroup(group.RouterGroup): } return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers) - else: - # result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) + else: # non-stream result = None async for message in webchat_adapter.send_webchat_message( pipeline_uuid, session_type, message_chain_obj diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index 52fc9294..2c39afbb 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -26,7 +26,7 @@ class WebChatSession: id: str message_lists: dict[str, list[WebChatMessage]] = {} resp_waiters: dict[int, asyncio.Future[WebChatMessage]] - resp_queues = dict[int, asyncio.Queue[WebChatMessage]] + resp_queues: dict[int, asyncio.Queue[WebChatMessage]] def __init__(self, id: str): self.id = id @@ -109,9 +109,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): # notify waiter if isinstance(message_source, platform_events.FriendMessage): - self.webchat_person_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data) + self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data) elif isinstance(message_source, platform_events.GroupMessage): - self.webchat_group_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data) + self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data) return message_data.model_dump() @@ -205,9 +205,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): message_id = len(use_session.get_message_list(pipeline_uuid)) + 1 - if is_stream: - use_session.resp_queues[message_id] = asyncio.Queue() - logger.debug(f'Initialized queue for message_id: {message_id}') + use_session.resp_queues[message_id] = asyncio.Queue() + logger.debug(f'Initialized queue for message_id: {message_id}') use_session.get_message_list(pipeline_uuid).append( WebChatMessage( @@ -242,6 +241,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid + # trigger pipeline if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) @@ -257,20 +257,31 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): yield resp_message.model_dump() break yield resp_message.model_dump() + use_session.resp_queues.pop(message_id) - else: + else: # non-stream # set waiter - waiter = asyncio.Future[WebChatMessage]() - use_session.resp_waiters[message_id] = waiter - waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) + # waiter = asyncio.Future[WebChatMessage]() + # use_session.resp_waiters[message_id] = waiter + # # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) - resp_message = await waiter + # resp_message = await waiter - resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 + # resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 - use_session.get_message_list(pipeline_uuid).append(resp_message) + # use_session.get_message_list(pipeline_uuid).append(resp_message) - yield resp_message.model_dump() + # yield resp_message.model_dump() + queue = use_session.resp_queues[message_id] + while True: + resp_message = await queue.get() + resp_message.id = msg_id + if resp_message.is_final: + resp_message.id = msg_id + use_session.get_message_list(pipeline_uuid).append(resp_message) + yield resp_message.model_dump() + break + yield resp_message.model_dump() def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]: """获取调试消息历史"""