Merge pull request #2041 from langbot-app/fix/websocket-chat-bug

Fix/websocket chat bug
This commit is contained in:
Guanchao Wang
2026-03-09 16:11:17 +08:00
committed by GitHub
2 changed files with 51 additions and 16 deletions

View File

@@ -37,16 +37,24 @@ class WebSocketSession:
id: str
message_lists: dict[str, list[WebSocketMessage]] = {}
"""消息列表 {pipeline_uuid: [messages]}"""
stream_message_indexes: dict[str, dict[str, int]] = {}
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
def __init__(self, id: str):
self.id = id
self.message_lists = {}
self.stream_message_indexes = {}
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
if pipeline_uuid not in self.message_lists:
self.message_lists[pipeline_uuid] = []
return self.message_lists[pipeline_uuid]
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
if pipeline_uuid not in self.stream_message_indexes:
self.stream_message_indexes[pipeline_uuid] = {}
return self.stream_message_indexes[pipeline_uuid]
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""WebSocket适配器 - 支持双向实时通信"""
@@ -89,7 +97,13 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
target_id: str,
message: platform_message.MessageChain,
) -> dict:
"""发送消息 - 这里用于主动推送消息到前端"""
"""发送消息 - 这里用于主动推送消息到前端
对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。
target_id 可能是 launcher_id如 websocket_xxx或 pipeline_uuid。
我们需要尝试两种方式来确保消息能够送达。
"""
# 获取当前的 pipeline_uuid
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
session_type = 'group' if target_type == 'group' else 'person'
@@ -189,10 +203,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
message_list = session.get_message_list(pipeline_uuid)
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
# 检查是否是新的流式消息通过bot_message对象判断
# 如果列表为空或者最后一条消息已经is_final=True则创建新消息
if not message_list or message_list[-1].is_final:
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
# Use it as the primary key to avoid overwriting an old card from a previous reply.
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
message_is_final = is_final and bot_message.tool_calls is None
if existing_index is None or existing_index >= len(message_list):
# 创建新消息
msg_id = len(message_list) + 1
message_data = WebSocketMessage(
@@ -201,27 +221,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=datetime.now().isoformat(),
is_final=is_final and bot_message.tool_calls is None,
is_final=message_is_final,
)
# 只有在is_final时才保存到历史记录
if is_final and bot_message.tool_calls is None:
message_list.append(message_data)
# 立即添加到历史记录即使is_final=False以便后续块可以更新它
message_list.append(message_data)
if resp_message_id:
stream_message_indexes[resp_message_id] = len(message_list) - 1
else:
# 更新最后一条消息
msg_id = message_list[-1].id
# 更新同一条流式消息
old_message = message_list[existing_index]
msg_id = old_message.id
message_data = WebSocketMessage(
id=msg_id,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=message_list[-1].timestamp, # 保持原始时间戳
is_final=is_final and bot_message.tool_calls is None,
timestamp=old_message.timestamp, # 保持原始时间戳
is_final=message_is_final,
)
# 如果是final更新历史记录中的最后一条
if is_final and bot_message.tool_calls is None:
message_list[-1] = message_data
# 更新历史记录中的对应消息
message_list[existing_index] = message_data
if message_is_final and resp_message_id:
stream_message_indexes.pop(resp_message_id, None)
# 直接广播到所有该pipeline的连接包含session_type信息
await ws_connection_manager.broadcast_to_pipeline(
@@ -430,6 +454,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
if session_type == 'person':
if pipeline_uuid in self.websocket_person_session.message_lists:
self.websocket_person_session.message_lists[pipeline_uuid] = []
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
else:
if pipeline_uuid in self.websocket_group_session.message_lists:
self.websocket_group_session.message_lists[pipeline_uuid] = []
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}

View File

@@ -337,7 +337,14 @@ class RuntimeConnectionHandler(handler.Handler):
)
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
# The func field is excluded during model_dump() in plugin side (marked as exclude=True),
# but it's a required field for LLMTool validation. We need to provide a placeholder
# function when reconstructing the LLMTool objects from serialized data.
async def _placeholder_func(**kwargs):
pass
funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]
result = await llm_model.provider.invoke_llm(
query=None,