From 2b8bd1cc719ff3912556a075937ec373fd3ced6a Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Mon, 9 Mar 2026 12:16:53 +0800 Subject: [PATCH] fix: invoke_llm failed when use plugin --- .../pkg/platform/sources/websocket_adapter.py | 50 +++++++++++++------ src/langbot/pkg/plugin/handler.py | 12 ++++- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/src/langbot/pkg/platform/sources/websocket_adapter.py b/src/langbot/pkg/platform/sources/websocket_adapter.py index d877c274..de35df9e 100644 --- a/src/langbot/pkg/platform/sources/websocket_adapter.py +++ b/src/langbot/pkg/platform/sources/websocket_adapter.py @@ -37,16 +37,24 @@ class WebSocketSession: id: str message_lists: dict[str, list[WebSocketMessage]] = {} """消息列表 {pipeline_uuid: [messages]}""" + stream_message_indexes: dict[str, dict[str, int]] = {} + """流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}""" def __init__(self, id: str): self.id = id self.message_lists = {} + self.stream_message_indexes = {} def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]: if pipeline_uuid not in self.message_lists: self.message_lists[pipeline_uuid] = [] return self.message_lists[pipeline_uuid] + def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]: + if pipeline_uuid not in self.stream_message_indexes: + self.stream_message_indexes[pipeline_uuid] = {} + return self.stream_message_indexes[pipeline_uuid] + class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): """WebSocket适配器 - 支持双向实时通信""" @@ -189,10 +197,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person' message_list = session.get_message_list(pipeline_uuid) + stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid) - # 检查是否是新的流式消息(通过bot_message对象判断) - # 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息 - if not message_list or message_list[-1].is_final: + # Streaming messages in LangBot have a stable resp_message_id during the same assistant reply. + # Use it as the primary key to avoid overwriting an old card from a previous reply. + resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '') + existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None + + message_is_final = is_final and bot_message.tool_calls is None + + if existing_index is None or existing_index >= len(message_list): # 创建新消息 msg_id = len(message_list) + 1 message_data = WebSocketMessage( @@ -201,27 +215,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) content=str(message), message_chain=[component.__dict__ for component in message], timestamp=datetime.now().isoformat(), - is_final=is_final and bot_message.tool_calls is None, + is_final=message_is_final, ) - # 只有在is_final时才保存到历史记录 - if is_final and bot_message.tool_calls is None: - message_list.append(message_data) + # 立即添加到历史记录(即使is_final=False),以便后续块可以更新它 + message_list.append(message_data) + if resp_message_id: + stream_message_indexes[resp_message_id] = len(message_list) - 1 else: - # 更新最后一条消息 - msg_id = message_list[-1].id + # 更新同一条流式消息 + old_message = message_list[existing_index] + msg_id = old_message.id message_data = WebSocketMessage( id=msg_id, role='assistant', content=str(message), message_chain=[component.__dict__ for component in message], - timestamp=message_list[-1].timestamp, # 保持原始时间戳 - is_final=is_final and bot_message.tool_calls is None, + timestamp=old_message.timestamp, # 保持原始时间戳 + is_final=message_is_final, ) - # 如果是final,更新历史记录中的最后一条 - if is_final and bot_message.tool_calls is None: - message_list[-1] = message_data + # 更新历史记录中的对应消息 + message_list[existing_index] = message_data + + if message_is_final and resp_message_id: + stream_message_indexes.pop(resp_message_id, None) # 直接广播到所有该pipeline的连接,包含session_type信息 await ws_connection_manager.broadcast_to_pipeline( @@ -430,6 +448,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) if session_type == 'person': if pipeline_uuid in self.websocket_person_session.message_lists: self.websocket_person_session.message_lists[pipeline_uuid] = [] + if pipeline_uuid in self.websocket_person_session.stream_message_indexes: + self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {} else: if pipeline_uuid in self.websocket_group_session.message_lists: self.websocket_group_session.message_lists[pipeline_uuid] = [] + if pipeline_uuid in self.websocket_group_session.stream_message_indexes: + self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {} diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index dbe4698c..590911cf 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -337,7 +337,17 @@ class RuntimeConnectionHandler(handler.Handler): ) messages_obj = [provider_message.Message.model_validate(message) for message in messages] - funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs] + + # The func field is excluded during model_dump() in plugin side (marked as exclude=True), + # but it's a required field for LLMTool validation. We need to provide a placeholder + # function when reconstructing the LLMTool objects from serialized data. + async def _placeholder_func(**kwargs): + pass + + funcs_obj = [ + resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) + for func in funcs + ] result = await llm_model.provider.invoke_llm( query=None,