mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
fix: invoke_llm failed when use plugin
This commit is contained in:
@@ -37,16 +37,24 @@ class WebSocketSession:
|
||||
id: str
|
||||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||||
"""消息列表 {pipeline_uuid: [messages]}"""
|
||||
stream_message_indexes: dict[str, dict[str, int]] = {}
|
||||
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
self.stream_message_indexes = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
self.message_lists[pipeline_uuid] = []
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
|
||||
if pipeline_uuid not in self.stream_message_indexes:
|
||||
self.stream_message_indexes[pipeline_uuid] = {}
|
||||
return self.stream_message_indexes[pipeline_uuid]
|
||||
|
||||
|
||||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebSocket适配器 - 支持双向实时通信"""
|
||||
@@ -189,10 +197,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
message_list = session.get_message_list(pipeline_uuid)
|
||||
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
|
||||
|
||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
||||
# 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息
|
||||
if not message_list or message_list[-1].is_final:
|
||||
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
|
||||
# Use it as the primary key to avoid overwriting an old card from a previous reply.
|
||||
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
|
||||
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
|
||||
|
||||
message_is_final = is_final and bot_message.tool_calls is None
|
||||
|
||||
if existing_index is None or existing_index >= len(message_list):
|
||||
# 创建新消息
|
||||
msg_id = len(message_list) + 1
|
||||
message_data = WebSocketMessage(
|
||||
@@ -201,27 +215,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
is_final=message_is_final,
|
||||
)
|
||||
|
||||
# 只有在is_final时才保存到历史记录
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list.append(message_data)
|
||||
# 立即添加到历史记录(即使is_final=False),以便后续块可以更新它
|
||||
message_list.append(message_data)
|
||||
if resp_message_id:
|
||||
stream_message_indexes[resp_message_id] = len(message_list) - 1
|
||||
else:
|
||||
# 更新最后一条消息
|
||||
msg_id = message_list[-1].id
|
||||
# 更新同一条流式消息
|
||||
old_message = message_list[existing_index]
|
||||
msg_id = old_message.id
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=message_list[-1].timestamp, # 保持原始时间戳
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
timestamp=old_message.timestamp, # 保持原始时间戳
|
||||
is_final=message_is_final,
|
||||
)
|
||||
|
||||
# 如果是final,更新历史记录中的最后一条
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list[-1] = message_data
|
||||
# 更新历史记录中的对应消息
|
||||
message_list[existing_index] = message_data
|
||||
|
||||
if message_is_final and resp_message_id:
|
||||
stream_message_indexes.pop(resp_message_id, None)
|
||||
|
||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
@@ -430,6 +448,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
if session_type == 'person':
|
||||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
|
||||
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
else:
|
||||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
|
||||
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
|
||||
@@ -337,7 +337,17 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
|
||||
|
||||
# The func field is excluded during model_dump() in plugin side (marked as exclude=True),
|
||||
# but it's a required field for LLMTool validation. We need to provide a placeholder
|
||||
# function when reconstructing the LLMTool objects from serialized data.
|
||||
async def _placeholder_func(**kwargs):
|
||||
pass
|
||||
|
||||
funcs_obj = [
|
||||
resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func})
|
||||
for func in funcs
|
||||
]
|
||||
|
||||
result = await llm_model.provider.invoke_llm(
|
||||
query=None,
|
||||
|
||||
Reference in New Issue
Block a user