diff --git a/pyproject.toml b/pyproject.toml index 501ab819..65f8312e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dependencies = [ "chromadb>=0.4.24", "qdrant-client (>=1.15.1,<2.0.0)", "pyseekdb==1.1.0.post3", - "langbot-plugin==0.3.0rc1", + "langbot-plugin==0.3.0", "asyncpg>=0.30.0", "line-bot-sdk>=3.19.0", "tboxsdk>=0.0.10", diff --git a/src/langbot/pkg/pipeline/config_coercion.py b/src/langbot/pkg/pipeline/config_coercion.py new file mode 100644 index 00000000..649f9051 --- /dev/null +++ b/src/langbot/pkg/pipeline/config_coercion.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# metadata type -> coercion function +_COERCE_MAP = { + 'integer': lambda v: int(v), + 'number': lambda v: float(v), + 'float': lambda v: float(v), +} + + +def _coerce_bool(v): + if isinstance(v, bool): + return v + if isinstance(v, str): + if v.lower() == 'true': + return True + if v.lower() == 'false': + return False + raise ValueError(f'Cannot convert string {v!r} to bool') + return bool(v) + + +def _coerce_value(value, expected_type: str): + """Convert a single value to the expected type. + + Returns the converted value, or the original value if no conversion needed. + """ + if value is None: + return value + + if expected_type == 'boolean': + if isinstance(value, bool): + return value + return _coerce_bool(value) + + coerce_fn = _COERCE_MAP.get(expected_type) + if coerce_fn is None: + return value + + # Already the correct type + if expected_type == 'integer' and isinstance(value, int) and not isinstance(value, bool): + return value + if expected_type in ('number', 'float') and isinstance(value, (int, float)) and not isinstance(value, bool): + return float(value) + + return coerce_fn(value) + + +def coerce_pipeline_config( + config: dict, + *metadata_list: dict, +) -> None: + """Coerce pipeline config values according to metadata type definitions. + + Walks each metadata dict (trigger, safety, ai, output) and converts + config values in-place so that strings coming from the JSON column are + cast to their declared types (integer, number/float, boolean). + + Args: + config: The pipeline config dict to modify in-place. + *metadata_list: Metadata dicts loaded from the YAML templates. + """ + for meta in metadata_list: + section_name = meta.get('name') + if not section_name or section_name not in config: + continue + + section = config[section_name] + if not isinstance(section, dict): + continue + + for stage_def in meta.get('stages', []): + stage_name = stage_def.get('name') + if not stage_name or stage_name not in section: + continue + + stage_config = section[stage_name] + if not isinstance(stage_config, dict): + continue + + for field_def in stage_def.get('config', []): + field_name = field_def.get('name') + field_type = field_def.get('type') + if not field_name or not field_type or field_name not in stage_config: + continue + + old_value = stage_config[field_name] + try: + new_value = _coerce_value(old_value, field_type) + if new_value is not old_value: + stage_config[field_name] = new_value + except (ValueError, TypeError) as e: + logger.warning( + 'Failed to coerce config %s.%s.%s (%r) to %s: %s', + section_name, + stage_name, + field_name, + old_value, + field_type, + e, + ) diff --git a/src/langbot/pkg/pipeline/pipelinemgr.py b/src/langbot/pkg/pipeline/pipelinemgr.py index d56f626c..5d0012d1 100644 --- a/src/langbot/pkg/pipeline/pipelinemgr.py +++ b/src/langbot/pkg/pipeline/pipelinemgr.py @@ -13,6 +13,7 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events import langbot_plugin.api.entities.events as events from ..utils import importutil +from .config_coercion import coerce_pipeline_config import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @@ -420,6 +421,14 @@ class PipelineManager: elif isinstance(pipeline_entity, dict): pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) + coerce_pipeline_config( + pipeline_entity.config, + getattr(self.ap, 'pipeline_config_meta_trigger', {'name': 'trigger', 'stages': []}), + getattr(self.ap, 'pipeline_config_meta_safety', {'name': 'safety', 'stages': []}), + getattr(self.ap, 'pipeline_config_meta_ai', {'name': 'ai', 'stages': []}), + getattr(self.ap, 'pipeline_config_meta_output', {'name': 'output', 'stages': []}), + ) + # initialize stage containers according to pipeline_entity.stages stage_containers: list[StageInstContainer] = [] for stage_name in pipeline_entity.stages: diff --git a/src/langbot/pkg/platform/botmgr.py b/src/langbot/pkg/platform/botmgr.py index ef40e3ef..44874cfb 100644 --- a/src/langbot/pkg/platform/botmgr.py +++ b/src/langbot/pkg/platform/botmgr.py @@ -282,6 +282,8 @@ class PlatformManager: return runtime_bot async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None: + if self.websocket_proxy_bot and self.websocket_proxy_bot.bot_entity.uuid == bot_uuid: + return self.websocket_proxy_bot for bot in self.bots: if bot.bot_entity.uuid == bot_uuid: return bot diff --git a/src/langbot/pkg/platform/sources/websocket_adapter.py b/src/langbot/pkg/platform/sources/websocket_adapter.py index 238276ee..01da9f10 100644 --- a/src/langbot/pkg/platform/sources/websocket_adapter.py +++ b/src/langbot/pkg/platform/sources/websocket_adapter.py @@ -37,16 +37,24 @@ class WebSocketSession: id: str message_lists: dict[str, list[WebSocketMessage]] = {} """消息列表 {pipeline_uuid: [messages]}""" + stream_message_indexes: dict[str, dict[str, int]] = {} + """流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}""" def __init__(self, id: str): self.id = id self.message_lists = {} + self.stream_message_indexes = {} def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]: if pipeline_uuid not in self.message_lists: self.message_lists[pipeline_uuid] = [] return self.message_lists[pipeline_uuid] + def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]: + if pipeline_uuid not in self.stream_message_indexes: + self.stream_message_indexes[pipeline_uuid] = {} + return self.stream_message_indexes[pipeline_uuid] + class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): """WebSocket适配器 - 支持双向实时通信""" @@ -89,20 +97,46 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) target_id: str, message: platform_message.MessageChain, ) -> dict: - """发送消息 - 这里用于主动推送消息到前端""" - message_data = { - 'type': 'bot_message', - 'target_type': target_type, - 'target_id': target_id, - 'content': str(message), - 'message_chain': [component.__dict__ for component in message], - 'timestamp': datetime.now().isoformat(), - } + """发送消息 - 这里用于主动推送消息到前端 - # 推送到所有相关连接 - await self.outbound_message_queue.put(message_data) + 对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。 + target_id 可能是 launcher_id(如 websocket_xxx)或 pipeline_uuid。 + 我们需要尝试两种方式来确保消息能够送达。 + """ + # 获取当前的 pipeline_uuid + pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid + session_type = 'group' if target_type == 'group' else 'person' - return message_data + # 选择会话 + session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session + + # 生成唯一消息ID + msg_id = len(session.get_message_list(pipeline_uuid)) + 1 + + message_data = WebSocketMessage( + id=msg_id, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + is_final=True, + ) + + # 保存到历史记录 + session.get_message_list(pipeline_uuid).append(message_data) + + # 直接广播到当前pipeline的连接 + await ws_connection_manager.broadcast_to_pipeline( + pipeline_uuid, + { + 'type': 'response', + 'session_type': session_type, + 'data': message_data.model_dump(), + }, + session_type=session_type, + ) + + return message_data.model_dump() async def reply_message( self, @@ -169,10 +203,16 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person' message_list = session.get_message_list(pipeline_uuid) + stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid) - # 检查是否是新的流式消息(通过bot_message对象判断) - # 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息 - if not message_list or message_list[-1].is_final: + # Streaming messages in LangBot have a stable resp_message_id during the same assistant reply. + # Use it as the primary key to avoid overwriting an old card from a previous reply. + resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '') + existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None + + message_is_final = is_final and bot_message.tool_calls is None + + if existing_index is None or existing_index >= len(message_list): # 创建新消息 msg_id = len(message_list) + 1 message_data = WebSocketMessage( @@ -181,27 +221,31 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) content=str(message), message_chain=[component.__dict__ for component in message], timestamp=datetime.now().isoformat(), - is_final=is_final and bot_message.tool_calls is None, + is_final=message_is_final, ) - # 只有在is_final时才保存到历史记录 - if is_final and bot_message.tool_calls is None: - message_list.append(message_data) + # 立即添加到历史记录(即使is_final=False),以便后续块可以更新它 + message_list.append(message_data) + if resp_message_id: + stream_message_indexes[resp_message_id] = len(message_list) - 1 else: - # 更新最后一条消息 - msg_id = message_list[-1].id + # 更新同一条流式消息 + old_message = message_list[existing_index] + msg_id = old_message.id message_data = WebSocketMessage( id=msg_id, role='assistant', content=str(message), message_chain=[component.__dict__ for component in message], - timestamp=message_list[-1].timestamp, # 保持原始时间戳 - is_final=is_final and bot_message.tool_calls is None, + timestamp=old_message.timestamp, # 保持原始时间戳 + is_final=message_is_final, ) - # 如果是final,更新历史记录中的最后一条 - if is_final and bot_message.tool_calls is None: - message_list[-1] = message_data + # 更新历史记录中的对应消息 + message_list[existing_index] = message_data + + if message_is_final and resp_message_id: + stream_message_indexes.pop(resp_message_id, None) # 直接广播到所有该pipeline的连接,包含session_type信息 await ws_connection_manager.broadcast_to_pipeline( @@ -410,6 +454,10 @@ class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter) if session_type == 'person': if pipeline_uuid in self.websocket_person_session.message_lists: self.websocket_person_session.message_lists[pipeline_uuid] = [] + if pipeline_uuid in self.websocket_person_session.stream_message_indexes: + self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {} else: if pipeline_uuid in self.websocket_group_session.message_lists: self.websocket_group_session.message_lists[pipeline_uuid] = [] + if pipeline_uuid in self.websocket_group_session.stream_message_indexes: + self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {} diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index dbe4698c..30464312 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -337,7 +337,14 @@ class RuntimeConnectionHandler(handler.Handler): ) messages_obj = [provider_message.Message.model_validate(message) for message in messages] - funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs] + + # The func field is excluded during model_dump() in plugin side (marked as exclude=True), + # but it's a required field for LLMTool validation. We need to provide a placeholder + # function when reconstructing the LLMTool objects from serialized data. + async def _placeholder_func(**kwargs): + pass + + funcs_obj = [resource_tool.LLMTool.model_validate({**func, 'func': _placeholder_func}) for func in funcs] result = await llm_model.provider.invoke_llm( query=None, diff --git a/tests/unit_tests/pipeline/test_config_coercion.py b/tests/unit_tests/pipeline/test_config_coercion.py new file mode 100644 index 00000000..a23f54de --- /dev/null +++ b/tests/unit_tests/pipeline/test_config_coercion.py @@ -0,0 +1,113 @@ +"""Unit tests for config_coercion module""" + +from __future__ import annotations + +import pytest + +from langbot.pkg.pipeline.config_coercion import _coerce_value, coerce_pipeline_config + + +class TestCoerceValue: + """Tests for _coerce_value function""" + + def test_none_passthrough(self): + assert _coerce_value(None, 'integer') is None + assert _coerce_value(None, 'boolean') is None + + def test_string_to_integer(self): + assert _coerce_value('120', 'integer') == 120 + assert _coerce_value('0', 'integer') == 0 + assert _coerce_value('-5', 'integer') == -5 + + def test_integer_passthrough(self): + assert _coerce_value(42, 'integer') == 42 + + def test_string_to_float(self): + assert _coerce_value('3.14', 'number') == 3.14 + assert _coerce_value('3.14', 'float') == 3.14 + + def test_int_to_float(self): + assert _coerce_value(3, 'number') == 3.0 + assert isinstance(_coerce_value(3, 'number'), float) + + def test_float_passthrough(self): + assert _coerce_value(3.14, 'float') == 3.14 + + def test_string_to_bool(self): + assert _coerce_value('true', 'boolean') is True + assert _coerce_value('True', 'boolean') is True + assert _coerce_value('false', 'boolean') is False + assert _coerce_value('False', 'boolean') is False + + def test_bool_passthrough(self): + assert _coerce_value(True, 'boolean') is True + assert _coerce_value(False, 'boolean') is False + + def test_invalid_bool_string_raises(self): + with pytest.raises(ValueError): + _coerce_value('notabool', 'boolean') + + def test_unknown_type_passthrough(self): + assert _coerce_value('hello', 'string') == 'hello' + assert _coerce_value('hello', 'unknown') == 'hello' + + def test_invalid_integer_raises(self): + with pytest.raises(ValueError): + _coerce_value('abc', 'integer') + + +class TestCoercePipelineConfig: + """Tests for coerce_pipeline_config function""" + + def _make_meta(self, section_name: str, stage_name: str, fields: list[dict]) -> dict: + return { + 'name': section_name, + 'stages': [{'name': stage_name, 'config': fields}], + } + + def test_coerce_integer_in_config(self): + config = {'trigger': {'misc': {'timeout': '120'}}} + meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}]) + coerce_pipeline_config(config, meta) + assert config['trigger']['misc']['timeout'] == 120 + + def test_coerce_boolean_in_config(self): + config = {'output': {'misc': {'at-sender': 'true'}}} + meta = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}]) + coerce_pipeline_config(config, meta) + assert config['output']['misc']['at-sender'] is True + + def test_missing_section_skipped(self): + config = {'ai': {}} + meta = self._make_meta('trigger', 'misc', [{'name': 'x', 'type': 'integer'}]) + coerce_pipeline_config(config, meta) # should not raise + + def test_missing_field_skipped(self): + config = {'trigger': {'misc': {}}} + meta = self._make_meta('trigger', 'misc', [{'name': 'nonexistent', 'type': 'integer'}]) + coerce_pipeline_config(config, meta) # should not raise + + def test_invalid_value_logs_warning(self, caplog): + config = {'trigger': {'misc': {'timeout': 'abc'}}} + meta = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}]) + import logging + + with caplog.at_level(logging.WARNING): + coerce_pipeline_config(config, meta) + assert config['trigger']['misc']['timeout'] == 'abc' # unchanged + assert 'Failed to coerce' in caplog.text + + def test_empty_metadata(self): + config = {'trigger': {'misc': {'timeout': '120'}}} + coerce_pipeline_config(config) # no metadata args, should not raise + + def test_multiple_metadata(self): + config = { + 'trigger': {'misc': {'timeout': '120'}}, + 'output': {'misc': {'at-sender': 'false'}}, + } + meta_trigger = self._make_meta('trigger', 'misc', [{'name': 'timeout', 'type': 'integer'}]) + meta_output = self._make_meta('output', 'misc', [{'name': 'at-sender', 'type': 'boolean'}]) + coerce_pipeline_config(config, meta_trigger, meta_output) + assert config['trigger']['misc']['timeout'] == 120 + assert config['output']['misc']['at-sender'] is False diff --git a/uv.lock b/uv.lock index 47175a1d..60cdb712 100644 --- a/uv.lock +++ b/uv.lock @@ -1937,7 +1937,7 @@ requires-dist = [ { name = "ebooklib", specifier = ">=0.18" }, { name = "gewechat-client", specifier = ">=0.1.5" }, { name = "html2text", specifier = ">=2024.2.26" }, - { name = "langbot-plugin", specifier = "==0.3.0rc1" }, + { name = "langbot-plugin", specifier = "==0.3.0" }, { name = "langchain", specifier = ">=0.2.0" }, { name = "langchain-text-splitters", specifier = ">=0.0.1" }, { name = "lark-oapi", specifier = ">=1.4.15" }, @@ -1993,7 +1993,7 @@ dev = [ [[package]] name = "langbot-plugin" -version = "0.3.0rc1" +version = "0.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiofiles" }, @@ -2011,9 +2011,9 @@ dependencies = [ { name = "watchdog" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/69/42/1bc1d50562182ca325960edeb9600f576c90b3352f2e5e19d11bb7d28d30/langbot_plugin-0.3.0rc1.tar.gz", hash = "sha256:0ecf0e646ea07aee9fb99d8283337b0926de8322b012c8e3a514ba54a4530598", size = 169886, upload-time = "2026-03-05T14:55:02.131Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/e5/3686b3225e5f2ee6e19a6050bb981b49a91f2450dff83deb5dfba13b3a2a/langbot_plugin-0.3.0.tar.gz", hash = "sha256:9add2d6e81c8cc7281863e4a92a33ed6228dcc0243f4327ac4062edc962dbf98", size = 169751, upload-time = "2026-03-08T09:54:27.102Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/67/048798d05dbfbffcb28093cbfb986ab81a0f8ee3db0daf9235ad2357c3ed/langbot_plugin-0.3.0rc1-py3-none-any.whl", hash = "sha256:d72991ecd527c9c1b1ec1526b40b67369bcdb89a79b277920c261de76fd069d8", size = 144141, upload-time = "2026-03-05T14:55:03.191Z" }, + { url = "https://files.pythonhosted.org/packages/72/51/18f0c1446bcb6712ff3d31d81ea708e3f0e671fde5da69598204a1df977d/langbot_plugin-0.3.0-py3-none-any.whl", hash = "sha256:37bfd3ce507448a6ec4444bec1bc6da1c9911c9df144dfd428febb71122077a6", size = 144096, upload-time = "2026-03-08T09:54:25.581Z" }, ] [[package]] diff --git a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx index 517d0f76..392430d3 100644 --- a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx +++ b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx @@ -120,6 +120,8 @@ export default function PipelineFormComponent({ // Track unsaved changes by comparing current form values against a saved snapshot const savedSnapshotRef = useRef(''); + // Track which dynamic form stages have completed their initial mount emission. + const initializedStagesRef = useRef>(new Set()); const watchedValues = form.watch(); const hasUnsavedChanges = useMemo(() => { if (!isEditMode || !savedSnapshotRef.current) return false; @@ -160,6 +162,7 @@ export default function PipelineFormComponent({ }; form.reset(loadedValues); savedSnapshotRef.current = JSON.stringify(loadedValues); + initializedStagesRef.current.clear(); }); } }, []); @@ -235,6 +238,33 @@ export default function PipelineFormComponent({ }); } + // Called from DynamicFormComponent/N8nAuthFormComponent onSubmit callbacks. + // On the first emission for a stage (mount-time default filling), the + // snapshot is synchronously re-captured so that hasUnsavedChanges stays false. + function handleDynamicFormEmit( + formName: keyof FormValues, + stageName: string, + values: object, + ) { + const stageKey = `${String(formName)}.${stageName}`; + const isFirstEmission = !initializedStagesRef.current.has(stageKey); + + const currentValues = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (form.getValues(formName) as Record) || {}; + form.setValue(formName, { + ...currentValues, + [stageName]: values, + }); + + if (isFirstEmission) { + initializedStagesRef.current.add(stageKey); + // Synchronously re-capture snapshot so that the useMemo comparison + // in the same render cycle still returns false. + savedSnapshotRef.current = JSON.stringify(form.getValues()); + } + } + function renderDynamicForms( stage: PipelineConfigStage, formName: keyof FormValues, @@ -264,13 +294,7 @@ export default function PipelineFormComponent({ {} } onSubmit={(values) => { - const currentValues = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (form.getValues(formName) as Record) || {}; - form.setValue(formName, { - ...currentValues, - [stage.name]: values, - }); + handleDynamicFormEmit(formName, stage.name, values); }} /> @@ -302,13 +326,7 @@ export default function PipelineFormComponent({ {} } onSubmit={(values) => { - const currentValues = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (form.getValues(formName) as Record) || {}; - form.setValue(formName, { - ...currentValues, - [stage.name]: values, - }); + handleDynamicFormEmit(formName, stage.name, values); }} /> @@ -333,13 +351,7 @@ export default function PipelineFormComponent({ (form.watch(formName) as Record)?.[stage.name] || {} } onSubmit={(values) => { - const currentValues = - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (form.getValues(formName) as Record) || {}; - form.setValue(formName, { - ...currentValues, - [stage.name]: values, - }); + handleDynamicFormEmit(formName, stage.name, values); }} />