mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 12:56:02 +00:00
wq
Merge branch 'master' into feat/dbm20-rag
This commit is contained in:
@@ -64,7 +64,7 @@ dependencies = [
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"pyseekdb==1.1.0.post3",
|
||||
"langbot-plugin==0.3.0rc1",
|
||||
"langbot-plugin==0.3.0",
|
||||
"asyncpg>=0.30.0",
|
||||
"line-bot-sdk>=3.19.0",
|
||||
"tboxsdk>=0.0.10",
|
||||
|
||||
105
src/langbot/pkg/pipeline/config_coercion.py
Normal file
105
src/langbot/pkg/pipeline/config_coercion.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# metadata type -> coercion function
|
||||
_COERCE_MAP = {
|
||||
'integer': lambda v: int(v),
|
||||
'number': lambda v: float(v),
|
||||
'float': lambda v: float(v),
|
||||
}
|
||||
|
||||
|
||||
def _coerce_bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
if v.lower() == 'true':
|
||||
return True
|
||||
if v.lower() == 'false':
|
||||
return False
|
||||
raise ValueError(f'Cannot convert string {v!r} to bool')
|
||||
return bool(v)
|
||||
|
||||
|
||||
def _coerce_value(value, expected_type: str):
|
||||
"""Convert a single value to the expected type.
|
||||
|
||||
Returns the converted value, or the original value if no conversion needed.
|
||||
"""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
if expected_type == 'boolean':
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return _coerce_bool(value)
|
||||
|
||||
coerce_fn = _COERCE_MAP.get(expected_type)
|
||||
if coerce_fn is None:
|
||||
return value
|
||||
|
||||
# Already the correct type
|
||||
if expected_type == 'integer' and isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
if expected_type in ('number', 'float') and isinstance(value, (int, float)) and not isinstance(value, bool):
|
||||
return float(value)
|
||||
|
||||
return coerce_fn(value)
|
||||
|
||||
|
||||
def coerce_pipeline_config(
|
||||
config: dict,
|
||||
*metadata_list: dict,
|
||||
) -> None:
|
||||
"""Coerce pipeline config values according to metadata type definitions.
|
||||
|
||||
Walks each metadata dict (trigger, safety, ai, output) and converts
|
||||
config values in-place so that strings coming from the JSON column are
|
||||
cast to their declared types (integer, number/float, boolean).
|
||||
|
||||
Args:
|
||||
config: The pipeline config dict to modify in-place.
|
||||
*metadata_list: Metadata dicts loaded from the YAML templates.
|
||||
"""
|
||||
for meta in metadata_list:
|
||||
section_name = meta.get('name')
|
||||
if not section_name or section_name not in config:
|
||||
continue
|
||||
|
||||
section = config[section_name]
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
|
||||
for stage_def in meta.get('stages', []):
|
||||
stage_name = stage_def.get('name')
|
||||
if not stage_name or stage_name not in section:
|
||||
continue
|
||||
|
||||
stage_config = section[stage_name]
|
||||
if not isinstance(stage_config, dict):
|
||||
continue
|
||||
|
||||
for field_def in stage_def.get('config', []):
|
||||
field_name = field_def.get('name')
|
||||
field_type = field_def.get('type')
|
||||
if not field_name or not field_type or field_name not in stage_config:
|
||||
continue
|
||||
|
||||
old_value = stage_config[field_name]
|
||||
try:
|
||||
new_value = _coerce_value(old_value, field_type)
|
||||
if new_value is not old_value:
|
||||
stage_config[field_name] = new_value
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
'Failed to coerce config %s.%s.%s (%r) to %s: %s',
|
||||
section_name,
|
||||
stage_name,
|
||||
field_name,
|
||||
old_value,
|
||||
field_type,
|
||||
e,
|
||||
)
|
||||
@@ -13,6 +13,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ..utils import importutil
|
||||
from .config_coercion import coerce_pipeline_config
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
@@ -420,6 +421,14 @@ class PipelineManager:
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
coerce_pipeline_config(
|
||||
pipeline_entity.config,
|
||||
getattr(self.ap, 'pipeline_config_meta_trigger', {'name': 'trigger', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_safety', {'name': 'safety', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_ai', {'name': 'ai', 'stages': []}),
|
||||
getattr(self.ap, 'pipeline_config_meta_output', {'name': 'output', 'stages': []}),
|
||||
)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
|
||||
@@ -282,6 +282,8 @@ class PlatformManager:
|
||||
return runtime_bot
|
||||
|
||||
async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None:
|
||||
if self.websocket_proxy_bot and self.websocket_proxy_bot.bot_entity.uuid == bot_uuid:
|
||||
return self.websocket_proxy_bot
|
||||
for bot in self.bots:
|
||||
if bot.bot_entity.uuid == bot_uuid:
|
||||
return bot
|
||||
|
||||
@@ -37,16 +37,24 @@ class WebSocketSession:
|
||||
id: str
|
||||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||||
"""消息列表 {pipeline_uuid: [messages]}"""
|
||||
stream_message_indexes: dict[str, dict[str, int]] = {}
|
||||
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
self.stream_message_indexes = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
self.message_lists[pipeline_uuid] = []
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
|
||||
if pipeline_uuid not in self.stream_message_indexes:
|
||||
self.stream_message_indexes[pipeline_uuid] = {}
|
||||
return self.stream_message_indexes[pipeline_uuid]
|
||||
|
||||
|
||||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebSocket适配器 - 支持双向实时通信"""
|
||||
@@ -89,20 +97,46 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
) -> dict:
|
||||
"""发送消息 - 这里用于主动推送消息到前端"""
|
||||
message_data = {
|
||||
'type': 'bot_message',
|
||||
'target_type': target_type,
|
||||
'target_id': target_id,
|
||||
'content': str(message),
|
||||
'message_chain': [component.__dict__ for component in message],
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
}
|
||||
"""发送消息 - 这里用于主动推送消息到前端
|
||||
|
||||
# 推送到所有相关连接
|
||||
await self.outbound_message_queue.put(message_data)
|
||||
对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。
|
||||
target_id 可能是 launcher_id(如 websocket_xxx)或 pipeline_uuid。
|
||||
我们需要尝试两种方式来确保消息能够送达。
|
||||
"""
|
||||
# 获取当前的 pipeline_uuid
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if target_type == 'group' else 'person'
|
||||
|
||||
return message_data
|
||||
# 选择会话
|
||||
session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||||
|
||||
# 生成唯一消息ID
|
||||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
session.get_message_list(pipeline_uuid).append(message_data)
|
||||
|
||||
# 直接广播到当前pipeline的连接
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'session_type': session_type,
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
session_type=session_type,
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -169,10 +203,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
message_list = session.get_message_list(pipeline_uuid)
|
||||
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
|
||||
|
||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
||||
# 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息
|
||||
if not message_list or message_list[-1].is_final:
|
||||
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
|
||||
# Use it as the primary key to avoid overwriting an old card from a previous reply.
|
||||
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
|
||||
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
|
||||
|
||||
message_is_final = is_final and bot_message.tool_calls is None
|
||||
|
||||
if existing_index is None or existing_index >= len(message_list):
|
||||
# 创建新消息
|
||||
msg_id = len(message_list) + 1
|
||||
message_data = WebSocketMessage(
|
||||
@@ -181,27 +221,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
is_final=message_is_final,
|
||||
)
|
||||
|
||||
# 只有在is_final时才保存到历史记录
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list.append(message_data)
|
||||
# 立即添加到历史记录(即使is_final=False),以便后续块可以更新它
|
||||
message_list.append(message_data)
|
||||
if resp_message_id:
|
||||
stream_message_indexes[resp_message_id] = len(message_list) - 1
|
||||
else:
|
||||
# 更新最后一条消息
|
||||
msg_id = message_list[-1].id
|
||||
# 更新同一条流式消息
|
||||
old_message = message_list[existing_index]
|
||||
msg_id = old_message.id
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=message_list[-1].timestamp, # 保持原始时间戳
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
timestamp=old_message.timestamp, # 保持原始时间戳
|
||||
is_final=message_is_final,
|
||||
)
|
||||
|
||||
# 如果是final,更新历史记录中的最后一条
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list[-1] = message_data
|
||||
# 更新历史记录中的对应消息
|
||||
message_list[existing_index] = message_data
|
||||
|
||||
if message_is_final and resp_message_id:
|
||||
stream_message_indexes.pop(resp_message_id, None)
|
||||
|
||||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
@@ -410,6 +454,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
|
||||
if session_type == 'person':
|
||||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
|
||||
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
else:
|
||||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||||
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
|
||||
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}
|
||||
|
||||
@@ -337,7 +337,14 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
messages_obj = [provider_message.Message.model_validate(message) for message in messages]
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs]
|
||||
|
||||
# The func field is excluded during model_dump() in plugin side (marked as exclude=True),
|
||||
# but it's a required field for LLMTool validation. We need to provide a placeholder
|
||||
# function when reconstructing the LLMTool objects from serialized data.
|
||||
async def _placeholder_func(**kwargs):
|
||||
pass
|
||||
|
||||
funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs]
|
||||
|
||||
result = await llm_model.provider.invoke_llm(
|
||||
query=None,
|
||||
|
||||
113
tests/unit_tests/pipeline/test_config_coercion.py
Normal file
113
tests/unit_tests/pipeline/test_config_coercion.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Unit tests for config_coercion module"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.pipeline.config_coercion import _coerce_value, coerce_pipeline_config
|
||||
|
||||
|
||||
class TestCoerceValue:
|
||||
"""Tests for _coerce_value function"""
|
||||
|
||||
def test_none_passthrough(self):
|
||||
assert _coerce_value(None, 'integer') is None
|
||||
assert _coerce_value(None, 'boolean') is None
|
||||
|
||||
def test_string_to_integer(self):
|
||||
assert _coerce_value('120', 'integer') == 120
|
||||
assert _coerce_value('0', 'integer') == 0
|
||||
assert _coerce_value('-5', 'integer') == -5
|
||||
|
||||
def test_integer_passthrough(self):
|
||||
assert _coerce_value(42, 'integer') == 42
|
||||
|
||||
def test_string_to_float(self):
|
||||
assert _coerce_value('3.14', 'number') == 3.14
|
||||
assert _coerce_value('3.14', 'float') == 3.14
|
||||
|
||||
def test_int_to_float(self):
|
||||
assert _coerce_value(3, 'number') == 3.0
|
||||
assert isinstance(_coerce_value(3, 'number'), float)
|
||||
|
||||
def test_float_passthrough(self):
|
||||
assert _coerce_value(3.14, 'float') == 3.14
|
||||
|
||||
def test_string_to_bool(self):
|
||||
assert _coerce_value('true', 'boolean') is True
|
||||
assert _coerce_value('True', 'boolean') is True
|
||||
assert _coerce_value('false', 'boolean') is False
|
||||
assert _coerce_value('False', 'boolean') is False
|
||||
|
||||
def test_bool_passthrough(self):
|
||||
assert _coerce_value(True, 'boolean') is True
|
||||
assert _coerce_value(False, 'boolean') is False
|
||||
|
||||
def test_invalid_bool_string_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_value('notabool', 'boolean')
|
||||
|
||||
def test_unknown_type_passthrough(self):
|
||||
assert _coerce_value('hello', 'string') == 'hello'
|
||||
assert _coerce_value('hello', 'unknown') == 'hello'
|
||||
|
||||
def test_invalid_integer_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_value('abc', 'integer')
|
||||
|
||||
|
||||
class TestCoercePipelineConfig:
|
||||
"""Tests for coerce_pipeline_config function"""
|
||||
|
||||
def _make_meta(self, section_name: str, stage_name: str, fields: list[dict]) -> dict:
|
||||
return {
|
||||
'name': section_name,
|
||||
'stages': [{'name': stage_name, 'config': fields}],
|
||||
}
|
||||
|
||||
def test_coerce_integer_in_config(self):
|
||||
config = {'trigger': {'misc': {'timeout': '120'}}}
|
||||
meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||
coerce_pipeline_config(config, meta)
|
||||
assert config['trigger']['misc']['timeout'] == 120
|
||||
|
||||
def test_coerce_boolean_in_config(self):
|
||||
config = {'output': {'misc': {'at-sender': 'true'}}}
|
||||
meta = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}])
|
||||
coerce_pipeline_config(config, meta)
|
||||
assert config['output']['misc']['at-sender'] is True
|
||||
|
||||
def test_missing_section_skipped(self):
|
||||
config = {'ai': {}}
|
||||
meta = self._make_meta('trigger', 'misc', [{'name': 'x', 'type': 'integer'}])
|
||||
coerce_pipeline_config(config, meta) # should not raise
|
||||
|
||||
def test_missing_field_skipped(self):
|
||||
config = {'trigger': {'misc': {}}}
|
||||
meta = self._make_meta('trigger', 'misc', [{'name': 'nonexistent', 'type': 'integer'}])
|
||||
coerce_pipeline_config(config, meta) # should not raise
|
||||
|
||||
def test_invalid_value_logs_warning(self, caplog):
|
||||
config = {'trigger': {'misc': {'timeout': 'abc'}}}
|
||||
meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||
import logging
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
coerce_pipeline_config(config, meta)
|
||||
assert config['trigger']['misc']['timeout'] == 'abc' # unchanged
|
||||
assert 'Failed to coerce' in caplog.text
|
||||
|
||||
def test_empty_metadata(self):
|
||||
config = {'trigger': {'misc': {'timeout': '120'}}}
|
||||
coerce_pipeline_config(config) # no metadata args, should not raise
|
||||
|
||||
def test_multiple_metadata(self):
|
||||
config = {
|
||||
'trigger': {'misc': {'timeout': '120'}},
|
||||
'output': {'misc': {'at-sender': 'false'}},
|
||||
}
|
||||
meta_trigger = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}])
|
||||
meta_output = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}])
|
||||
coerce_pipeline_config(config, meta_trigger, meta_output)
|
||||
assert config['trigger']['misc']['timeout'] == 120
|
||||
assert config['output']['misc']['at-sender'] is False
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -1937,7 +1937,7 @@ requires-dist = [
|
||||
{ name = "ebooklib", specifier = ">=0.18" },
|
||||
{ name = "gewechat-client", specifier = ">=0.1.5" },
|
||||
{ name = "html2text", specifier = ">=2024.2.26" },
|
||||
{ name = "langbot-plugin", specifier = "==0.3.0rc1" },
|
||||
{ name = "langbot-plugin", specifier = "==0.3.0" },
|
||||
{ name = "langchain", specifier = ">=0.2.0" },
|
||||
{ name = "langchain-text-splitters", specifier = ">=0.0.1" },
|
||||
{ name = "lark-oapi", specifier = ">=1.4.15" },
|
||||
@@ -1993,7 +1993,7 @@ dev = [
|
||||
|
||||
[[package]]
|
||||
name = "langbot-plugin"
|
||||
version = "0.3.0rc1"
|
||||
version = "0.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
@@ -2011,9 +2011,9 @@ dependencies = [
|
||||
{ name = "watchdog" },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/42/1bc1d50562182ca325960edeb9600f576c90b3352f2e5e19d11bb7d28d30/langbot_plugin-0.3.0rc1.tar.gz", hash = "sha256:0ecf0e646ea07aee9fb99d8283337b0926de8322b012c8e3a514ba54a4530598", size = 169886, upload-time = "2026-03-05T14:55:02.131Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/e5/3686b3225e5f2ee6e19a6050bb981b49a91f2450dff83deb5dfba13b3a2a/langbot_plugin-0.3.0.tar.gz", hash = "sha256:9add2d6e81c8cc7281863e4a92a33ed6228dcc0243f4327ac4062edc962dbf98", size = 169751, upload-time = "2026-03-08T09:54:27.102Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/77/67/048798d05dbfbffcb28093cbfb986ab81a0f8ee3db0daf9235ad2357c3ed/langbot_plugin-0.3.0rc1-py3-none-any.whl", hash = "sha256:d72991ecd527c9c1b1ec1526b40b67369bcdb89a79b277920c261de76fd069d8", size = 144141, upload-time = "2026-03-05T14:55:03.191Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/51/18f0c1446bcb6712ff3d31d81ea708e3f0e671fde5da69598204a1df977d/langbot_plugin-0.3.0-py3-none-any.whl", hash = "sha256:37bfd3ce507448a6ec4444bec1bc6da1c9911c9df144dfd428febb71122077a6", size = 144096, upload-time = "2026-03-08T09:54:25.581Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -120,6 +120,8 @@ export default function PipelineFormComponent({
|
||||
|
||||
// Track unsaved changes by comparing current form values against a saved snapshot
|
||||
const savedSnapshotRef = useRef<string>('');
|
||||
// Track which dynamic form stages have completed their initial mount emission.
|
||||
const initializedStagesRef = useRef<Set<string>>(new Set());
|
||||
const watchedValues = form.watch();
|
||||
const hasUnsavedChanges = useMemo(() => {
|
||||
if (!isEditMode || !savedSnapshotRef.current) return false;
|
||||
@@ -160,6 +162,7 @@ export default function PipelineFormComponent({
|
||||
};
|
||||
form.reset(loadedValues);
|
||||
savedSnapshotRef.current = JSON.stringify(loadedValues);
|
||||
initializedStagesRef.current.clear();
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
@@ -235,6 +238,33 @@ export default function PipelineFormComponent({
|
||||
});
|
||||
}
|
||||
|
||||
// Called from DynamicFormComponent/N8nAuthFormComponent onSubmit callbacks.
|
||||
// On the first emission for a stage (mount-time default filling), the
|
||||
// snapshot is synchronously re-captured so that hasUnsavedChanges stays false.
|
||||
function handleDynamicFormEmit(
|
||||
formName: keyof FormValues,
|
||||
stageName: string,
|
||||
values: object,
|
||||
) {
|
||||
const stageKey = `${String(formName)}.${stageName}`;
|
||||
const isFirstEmission = !initializedStagesRef.current.has(stageKey);
|
||||
|
||||
const currentValues =
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(form.getValues(formName) as Record<string, any>) || {};
|
||||
form.setValue(formName, {
|
||||
...currentValues,
|
||||
[stageName]: values,
|
||||
});
|
||||
|
||||
if (isFirstEmission) {
|
||||
initializedStagesRef.current.add(stageKey);
|
||||
// Synchronously re-capture snapshot so that the useMemo comparison
|
||||
// in the same render cycle still returns false.
|
||||
savedSnapshotRef.current = JSON.stringify(form.getValues());
|
||||
}
|
||||
}
|
||||
|
||||
function renderDynamicForms(
|
||||
stage: PipelineConfigStage,
|
||||
formName: keyof FormValues,
|
||||
@@ -264,13 +294,7 @@ export default function PipelineFormComponent({
|
||||
{}
|
||||
}
|
||||
onSubmit={(values) => {
|
||||
const currentValues =
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(form.getValues(formName) as Record<string, any>) || {};
|
||||
form.setValue(formName, {
|
||||
...currentValues,
|
||||
[stage.name]: values,
|
||||
});
|
||||
handleDynamicFormEmit(formName, stage.name, values);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
@@ -302,13 +326,7 @@ export default function PipelineFormComponent({
|
||||
{}
|
||||
}
|
||||
onSubmit={(values) => {
|
||||
const currentValues =
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(form.getValues(formName) as Record<string, any>) || {};
|
||||
form.setValue(formName, {
|
||||
...currentValues,
|
||||
[stage.name]: values,
|
||||
});
|
||||
handleDynamicFormEmit(formName, stage.name, values);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
@@ -333,13 +351,7 @@ export default function PipelineFormComponent({
|
||||
(form.watch(formName) as Record<string, any>)?.[stage.name] || {}
|
||||
}
|
||||
onSubmit={(values) => {
|
||||
const currentValues =
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(form.getValues(formName) as Record<string, any>) || {};
|
||||
form.setValue(formName, {
|
||||
...currentValues,
|
||||
[stage.name]: values,
|
||||
});
|
||||
handleDynamicFormEmit(formName, stage.name, values);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user