From 7fb3cfa638be4c4fd53f064cee629a3304de1b78 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Sat, 6 Jun 2026 00:21:19 +0800 Subject: [PATCH] refactor(provider): simplify litellm capabilities --- src/langbot/pkg/pipeline/preproc/preproc.py | 8 +- src/langbot/pkg/provider/modelmgr/modelmgr.py | 32 +++- .../modelmgr/requesters/litellmchat.py | 162 +++++++++++++++-- .../pkg/provider/runners/localagent.py | 169 ++++++++---------- tests/unit_tests/provider/test_litellmchat.py | 146 ++++++++++++++- .../provider/test_localagent_sandbox_exec.py | 41 ++++- .../components/AddModelPopover.tsx | 26 --- 7 files changed, 443 insertions(+), 141 deletions(-) diff --git a/src/langbot/pkg/pipeline/preproc/preproc.py b/src/langbot/pkg/pipeline/preproc/preproc.py index 8aa15750..e180b0d2 100644 --- a/src/langbot/pkg/pipeline/preproc/preproc.py +++ b/src/langbot/pkg/pipeline/preproc/preproc.py @@ -109,7 +109,7 @@ class PreProcessor(stage.PipelineStage): if llm_model: query.use_llm_model_uuid = llm_model.model_entity.uuid - if llm_model.model_entity.abilities.__contains__('func_call'): + if 'func_call' in (llm_model.model_entity.abilities or []): # Get bound plugins and MCP servers for filtering tools bound_plugins = query.variables.get('_pipeline_bound_plugins', None) bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) @@ -162,7 +162,7 @@ class PreProcessor(stage.PipelineStage): if ( selected_runner == 'local-agent' and llm_model - and not llm_model.model_entity.abilities.__contains__('vision') + and 'vision' not in (llm_model.model_entity.abilities or []) ): for msg in query.messages: if isinstance(msg.content, list): @@ -181,7 +181,7 @@ class PreProcessor(stage.PipelineStage): plain_text += me.text elif isinstance(me, platform_message.Image): if selected_runner != 'local-agent' or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') + llm_model and 'vision' in (llm_model.model_entity.abilities or []) ): if me.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(me.base64)) @@ -202,7 +202,7 @@ class PreProcessor(stage.PipelineStage): content_list.append(provider_message.ContentElement.from_text(msg.text)) elif isinstance(msg, platform_message.Image): if selected_runner != 'local-agent' or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') + llm_model and 'vision' in (llm_model.model_entity.abilities or []) ): if msg.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(msg.base64)) diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index ee7f4d3b..3609f147 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -37,16 +37,39 @@ class ModelManager: self.requester_components = [] self.requester_dict = {} + @staticmethod + def _get_litellm_provider_from_manifest(component: engine.Component | None) -> str | None: + if component is None: + return None + + spec = getattr(component, 'spec', None) or {} + litellm_provider = None + + if isinstance(spec, dict): + litellm_provider = spec.get('litellm_provider') + else: + getter = getattr(spec, 'get', None) + if callable(getter): + try: + litellm_provider = getter('litellm_provider') + except Exception: + litellm_provider = None + + if isinstance(litellm_provider, str) and litellm_provider: + return litellm_provider + return None + async def initialize(self): self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: # Skip components that use litellm_provider (they will use litellmchat.py instead) - if component.spec.get('litellm_provider'): + litellm_provider = self._get_litellm_provider_from_manifest(component) + if litellm_provider: self.ap.logger.debug( f'Skipping Python class loading for {component.metadata.name} ' - f'(uses litellm_provider={component.spec.get("litellm_provider")})' + f'(uses litellm_provider={litellm_provider})' ) continue requester_dict[component.metadata.name] = component.get_python_component_class() @@ -303,17 +326,18 @@ class ModelManager: # Get requester manifest to check for litellm_provider requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester) + litellm_provider = self._get_litellm_provider_from_manifest(requester_manifest) # Build config from base_url config = {'base_url': provider_entity.base_url} # Check if requester manifest specifies litellm_provider - if requester_manifest and requester_manifest.spec.get('litellm_provider'): + if litellm_provider: from .requesters import litellmchat # Use unified LiteLLMRequester with provider prefix # Map litellm_provider (YAML spec) to custom_llm_provider (config) - config['custom_llm_provider'] = requester_manifest.spec['litellm_provider'] + config['custom_llm_provider'] = litellm_provider requester_inst = litellmchat.LiteLLMRequester( ap=self.ap, config=config, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index e0eb3eb8..236d4723 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -16,6 +16,9 @@ import langbot_plugin.api.entities.builtin.provider.message as provider_message class LiteLLMRequester(requester.ProviderAPIRequester): """LiteLLM unified API requester supporting chat, embedding, and rerank.""" + _EMBEDDING_MODEL_HINTS = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding') + _RERANK_MODEL_HINTS = ('rerank', 're-rank', 're_rank') + default_config: dict[str, typing.Any] = { 'base_url': '', 'timeout': 120, @@ -36,10 +39,90 @@ class LiteLLMRequester(requester.ProviderAPIRequester): provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '') if provider: # LiteLLM format: provider/model_name + if model_name.startswith(f'{provider}/'): + return model_name return f'{provider}/{model_name}' # If no custom provider, assume model_name already includes prefix or is OpenAI-compatible return model_name + def _get_custom_llm_provider(self) -> str | None: + return self.requester_cfg.get('custom_llm_provider') or None + + def _safe_litellm_bool_helper(self, helper_name: str, model_name: str) -> bool: + """Call a LiteLLM boolean capability helper without letting metadata gaps fail requests.""" + helper = getattr(litellm, helper_name, None) + if not callable(helper): + return False + + provider = self._get_custom_llm_provider() + candidates: list[tuple[str, str | None]] = [(model_name, provider)] + litellm_model_name = self._build_litellm_model_name(model_name) + if litellm_model_name != model_name: + candidates.append((litellm_model_name, None)) + + for candidate_model, candidate_provider in candidates: + try: + if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)): + return True + except Exception: + continue + return False + + def _safe_context_length(self, model_name: str) -> int | None: + helper = getattr(litellm, 'get_max_tokens', None) + if not callable(helper): + return None + + candidates = [model_name] + litellm_model_name = self._build_litellm_model_name(model_name) + if litellm_model_name != model_name: + candidates.append(litellm_model_name) + + for candidate in candidates: + try: + max_tokens = helper(candidate) + except Exception: + continue + if isinstance(max_tokens, int) and max_tokens > 0: + return max_tokens + return None + + def _supports_function_calling(self, model_name: str) -> bool: + return self._safe_litellm_bool_helper('supports_function_calling', model_name) + + def _supports_vision(self, model_name: str) -> bool: + return self._safe_litellm_bool_helper('supports_vision', model_name) + + def _infer_model_type(self, model_id: str) -> str: + normalized_id = (model_id or '').lower() + if any(kw in normalized_id for kw in self._RERANK_MODEL_HINTS): + return 'rerank' + if any(kw in normalized_id for kw in self._EMBEDDING_MODEL_HINTS): + return 'embedding' + return 'llm' + + def _enrich_scanned_model(self, model_id: str) -> dict[str, typing.Any]: + model_type = self._infer_model_type(model_id) + scanned_model: dict[str, typing.Any] = { + 'id': model_id, + 'name': model_id, + 'type': model_type, + } + + if model_type == 'llm': + abilities = [] + if self._supports_function_calling(model_id): + abilities.append('func_call') + if self._supports_vision(model_id): + abilities.append('vision') + scanned_model['abilities'] = abilities + + context_length = self._safe_context_length(model_id) + if context_length is not None: + scanned_model['context_length'] = context_length + + return scanned_model + def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]: """Convert LangBot messages to LiteLLM/OpenAI format.""" req_messages = [] @@ -121,6 +204,64 @@ class LiteLLMRequester(requester.ProviderAPIRequester): """Extract usage info from a non-streaming LiteLLM response.""" return self._normalize_usage(getattr(response, 'usage', None)) + @staticmethod + def _as_dict(value: typing.Any) -> dict: + if value is None: + return {} + if isinstance(value, dict): + return value + if hasattr(value, 'model_dump'): + return value.model_dump() + return {} + + def _normalize_stream_tool_calls( + self, + raw_tool_calls: typing.Any, + tool_call_state: dict[int, dict[str, str]], + ) -> list[dict] | None: + """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them.""" + if not raw_tool_calls: + return None + + normalized = [] + for fallback_index, raw_tool_call in enumerate(raw_tool_calls): + tool_call = self._as_dict(raw_tool_call) + index = tool_call.get('index') + if not isinstance(index, int): + index = fallback_index + + state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''}) + if tool_call.get('id'): + state['id'] = tool_call['id'] + if tool_call.get('type'): + state['type'] = tool_call['type'] + + function = self._as_dict(tool_call.get('function')) + if function.get('name'): + state['name'] = function['name'] + + arguments = function.get('arguments') + if arguments is None: + arguments = '' + elif not isinstance(arguments, str): + arguments = str(arguments) + + if not state['id'] or not state['name']: + continue + + normalized.append( + { + 'id': state['id'], + 'type': state['type'] or 'function', + 'function': { + 'name': state['name'], + 'arguments': arguments, + }, + } + ) + + return normalized or None + def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict: """Apply common requester config to args dict.""" if self.requester_cfg.get('base_url'): @@ -189,6 +330,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) if tools: args['tools'] = tools + args.setdefault('tool_choice', 'auto') return args @@ -240,6 +382,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): chunk_idx = 0 role = 'assistant' + tool_call_state: dict[int, dict[str, str]] = {} try: response = await acompletion(**args) @@ -283,14 +426,16 @@ class LiteLLMRequester(requester.ProviderAPIRequester): # Use reasoning_content as the displayed content delta_content = reasoning_content - if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): + tool_calls = self._normalize_stream_tool_calls(delta.get('tool_calls'), tool_call_state) + + if chunk_idx == 0 and not delta_content and not tool_calls: chunk_idx += 1 continue chunk_data = { 'role': role, 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), + 'tool_calls': tool_calls, 'is_final': bool(finish_reason), } @@ -412,18 +557,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): if not model_id: continue - # Infer model type - normalized_id = (model_id or '').lower() - embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding') - model_type = 'embedding' if any(kw in normalized_id for kw in embedding_keywords) else 'llm' - - models.append( - { - 'id': model_id, - 'name': model_id, - 'type': model_type, - } - ) + models.append(self._enrich_scanned_model(model_id)) models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower())) diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 28d014d0..11c56699 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -41,6 +41,64 @@ SANDBOX_EXEC_SYSTEM_GUIDANCE = ( MAX_TOOL_CALL_ROUNDS = 128 +def _model_has_ability(model: modelmgr_requester.RuntimeLLMModel, ability: str) -> bool: + return ability in (model.model_entity.abilities or []) + + +class _StreamAccumulator: + """Accumulate streamed content and fragmented OpenAI-style tool calls.""" + + def __init__(self, msg_sequence: int = 0, initial_content: str | None = None): + self.tool_calls_map: dict[str, provider_message.ToolCall] = {} + self.msg_idx = 0 + self.accumulated_content = initial_content or '' + self.last_role = 'assistant' + self.msg_sequence = msg_sequence + + def add(self, msg: provider_message.MessageChunk) -> provider_message.MessageChunk | None: + self.msg_idx += 1 + + if msg.role: + self.last_role = msg.role + + if msg.content: + self.accumulated_content += msg.content + + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in self.tool_calls_map: + self.tool_calls_map[tool_call.id] = provider_message.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=provider_message.FunctionCall( + name=tool_call.function.name if tool_call.function else '', + arguments='', + ), + ) + if tool_call.function and tool_call.function.arguments: + self.tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + + if self.msg_idx % 8 == 0 or msg.is_final: + self.msg_sequence += 1 + return provider_message.MessageChunk( + role=self.last_role, + content=self.accumulated_content, + tool_calls=list(self.tool_calls_map.values()) if (self.tool_calls_map and msg.is_final) else None, + is_final=msg.is_final, + msg_sequence=self.msg_sequence, + ) + + return None + + def final_message(self) -> provider_message.MessageChunk: + return provider_message.MessageChunk( + role=self.last_role, + content=self.accumulated_content, + tool_calls=list(self.tool_calls_map.values()) if self.tool_calls_map else None, + msg_sequence=self.msg_sequence, + ) + + @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """Local agent request runner""" @@ -105,7 +163,7 @@ class LocalAgentRunner(runner.RequestRunner): query, model, messages, - funcs if model.model_entity.abilities.__contains__('func_call') else [], + funcs if _model_has_ability(model, 'func_call') else [], extra_args=model.model_entity.extra_args, remove_think=remove_think, ) @@ -135,7 +193,7 @@ class LocalAgentRunner(runner.RequestRunner): query, model, messages, - funcs if model.model_entity.abilities.__contains__('func_call') else [], + funcs if _model_has_ability(model, 'func_call') else [], extra_args=model.model_entity.extra_args, remove_think=remove_think, ) @@ -302,11 +360,7 @@ class LocalAgentRunner(runner.RequestRunner): final_msg = msg else: # Streaming: invoke with fallback - tool_calls_map: dict[str, provider_message.ToolCall] = {} - msg_idx = 0 - accumulated_content = '' - last_role = 'assistant' - msg_sequence = 1 + stream_accumulator = _StreamAccumulator(msg_sequence=1) stream_src, use_llm_model = await self._invoke_stream_with_fallback( query, @@ -316,44 +370,12 @@ class LocalAgentRunner(runner.RequestRunner): remove_think, ) async for msg in stream_src: - msg_idx = msg_idx + 1 - - if msg.role: - last_role = msg.role - - if msg.content: - accumulated_content += msg.content - - if msg.tool_calls: - for tool_call in msg.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = provider_message.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=provider_message.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - if msg_idx % 8 == 0 or msg.is_final: - msg_sequence += 1 - yield provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, - is_final=msg.is_final, - msg_sequence=msg_sequence, - ) + chunk = stream_accumulator.add(msg) + if chunk: + yield chunk initial_response_emitted = True - final_msg = provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, - msg_sequence=msg_sequence, - ) + final_msg = stream_accumulator.final_message() pending_tool_calls = final_msg.tool_calls first_content = final_msg.content @@ -438,69 +460,36 @@ class LocalAgentRunner(runner.RequestRunner): ) if is_stream: - tool_calls_map = {} - msg_idx = 0 - accumulated_content = '' - last_role = 'assistant' - msg_sequence = first_end_sequence + stream_accumulator = _StreamAccumulator( + msg_sequence=first_end_sequence, + initial_content=first_content, + ) tool_stream_src = use_llm_model.provider.invoke_llm_stream( query, use_llm_model, req_messages, - query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [], + query.use_funcs + if _model_has_ability(use_llm_model, 'func_call') + else [], extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) async for msg in tool_stream_src: - msg_idx += 1 + chunk = stream_accumulator.add(msg) + if chunk: + yield chunk - if msg.role: - last_role = msg.role - - # Prepend first-round content on first chunk of tool-call round - if msg_idx == 1: - accumulated_content = first_content if first_content is not None else accumulated_content - - if msg.content: - accumulated_content += msg.content - - if msg.tool_calls: - for tool_call in msg.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = provider_message.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=provider_message.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - if msg_idx % 8 == 0 or msg.is_final: - msg_sequence += 1 - yield provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, - is_final=msg.is_final, - msg_sequence=msg_sequence, - ) - - final_msg = provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, - msg_sequence=msg_sequence, - ) + final_msg = stream_accumulator.final_message() else: # Non-streaming: use committed model directly (no fallback in tool loop) msg = await use_llm_model.provider.invoke_llm( query, use_llm_model, req_messages, - query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [], + query.use_funcs + if _model_has_ability(use_llm_model, 'func_call') + else [], extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index ad8d9fd3..f44ba4ba 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -68,6 +68,12 @@ class TestBuildLiteLLMModelName: result = requester._build_litellm_model_name('gpt-4o') assert result == 'openai/gpt-4o' + def test_avoid_duplicate_provider_prefix(self): + """Test model name with an existing matching provider prefix.""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'}) + result = requester._build_litellm_model_name('openai/gpt-4o') + assert result == 'openai/gpt-4o' + def test_override_provider(self): """Test override provider via parameter""" requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'}) @@ -151,7 +157,7 @@ class TestInvokeLLMStreamUsage: calls record 0 tokens. """ - def _make_chunk(self, *, content=None, finish_reason=None, usage=None, has_choice=True): + def _make_chunk(self, *, content=None, tool_calls=None, finish_reason=None, usage=None, has_choice=True): chunk = Mock() if usage is not None: chunk.usage = usage @@ -161,7 +167,7 @@ class TestInvokeLLMStreamUsage: choice = Mock() delta = Mock() delta.model_dump = Mock( - return_value={'role': 'assistant', 'content': content, 'tool_calls': None} + return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls} ) choice.delta = delta choice.finish_reason = finish_reason @@ -250,6 +256,78 @@ class TestInvokeLLMStreamUsage: assert query.variables['_stream_usage']['total_tokens'] == 12 + @pytest.mark.asyncio + async def test_stream_tool_call_delta_missing_id_and_name(self): + """LiteLLM may stream tool-call argument deltas with id/name set to None.""" + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock( + return_value=[{'type': 'function', 'function': {'name': 'qa_plugin_echo'}}] + ) + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + chunks = [ + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'id': 'call_123', + 'type': 'function', + 'function': {'name': 'qa_plugin_echo', 'arguments': ''}, + } + ] + ), + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'id': None, + 'type': None, + 'function': {'name': None, 'arguments': '{"text":'}, + } + ] + ), + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'function': {'arguments': '"plugin-tool-ok"}'}, + } + ] + ), + self._make_chunk(finish_reason='tool_calls'), + ] + + async def _aiter(*args, **kwargs): + for c in chunks: + yield c + + query = Mock(spec=pipeline_query.Query) + query.variables = {} + messages = [provider_message.Message(role='user', content='Call the tool')] + funcs = [Mock()] + + with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())): + collected = [ + chunk async for chunk in requester.invoke_llm_stream( + query=query, + model=model, + messages=messages, + funcs=funcs, + ) + ] + + tool_chunks = [chunk for chunk in collected if chunk.tool_calls] + assert len(tool_chunks) == 3 + assert tool_chunks[1].tool_calls[0].id == 'call_123' + assert tool_chunks[1].tool_calls[0].function.name == 'qa_plugin_echo' + assert tool_chunks[1].tool_calls[0].function.arguments == '{"text":' + assert tool_chunks[2].tool_calls[0].function.arguments == '"plugin-tool-ok"}' + class TestProcessThinkingContent: """Test _process_thinking_content method""" @@ -499,6 +577,32 @@ class TestInvokeLLM: ) assert result_msg.tool_calls is not None + called_kwargs = litellmchat.acompletion.await_args.kwargs + assert called_kwargs['tools'] == [{'type': 'function', 'function': {'name': 'get_weather'}}] + assert called_kwargs['tool_choice'] == 'auto' + + @pytest.mark.asyncio + async def test_build_completion_args_preserves_explicit_tool_choice(self): + """Model extra args can override the default auto tool choice.""" + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock( + return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}] + ) + + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + model.model_entity.extra_args = {'tool_choice': 'required'} + + import langbot_plugin.api.entities.builtin.resource.tool as resource_tool + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + funcs = [Mock(spec=resource_tool.LLMTool)] + messages = [provider_message.Message(role='user', content='What is the weather?')] + + args = await requester._build_completion_args(model, messages, funcs) + + assert args['tool_choice'] == 'required' @pytest.mark.asyncio async def test_invoke_llm_error_handling(self): @@ -754,6 +858,44 @@ class TestScanModels: embedding_models = [m for m in result['models'] if m['type'] == 'embedding'] assert len(embedding_models) == 1 + @pytest.mark.asyncio + async def test_scan_models_enriches_llm_abilities_and_context_length(self): + """Scanned LLM models get LiteLLM-derived abilities and context length.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(side_effect=lambda model_id: model_id == 'gpt-4o') + requester._supports_vision = Mock(side_effect=lambda model_id: model_id == 'gpt-4o') + requester._safe_context_length = Mock(side_effect=lambda model_id: 128000 if model_id == 'gpt-4o' else None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + {'id': 'gpt-4o'}, + {'id': 'text-embedding-3-small'}, + {'id': 'bge-reranker-v2'}, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + by_id = {model['id']: model for model in result['models']} + assert by_id['gpt-4o']['abilities'] == ['func_call', 'vision'] + assert by_id['gpt-4o']['context_length'] == 128000 + assert by_id['text-embedding-3-small']['type'] == 'embedding' + assert by_id['bge-reranker-v2']['type'] == 'rerank' + @pytest.mark.asyncio async def test_scan_models_no_base_url(self): """Test scan_models without base_url raises error""" diff --git a/tests/unit_tests/provider/test_localagent_sandbox_exec.py b/tests/unit_tests/provider/test_localagent_sandbox_exec.py index daa4eb2d..08b4c540 100644 --- a/tests/unit_tests/provider/test_localagent_sandbox_exec.py +++ b/tests/unit_tests/provider/test_localagent_sandbox_exec.py @@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.session as provider_session -from langbot.pkg.provider.runners.localagent import LocalAgentRunner +from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator class RecordingProvider: @@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query: ) +def test_stream_accumulator_merges_fragmented_tool_call_arguments(): + accumulator = _StreamAccumulator(msg_sequence=1) + + assert ( + accumulator.add( + provider_message.MessageChunk( + role='assistant', + tool_calls=[ + provider_message.ToolCall( + id='call-1', + type='function', + function=provider_message.FunctionCall(name='exec', arguments='{"command":'), + ) + ], + ) + ) + is None + ) + + emitted = accumulator.add( + provider_message.MessageChunk( + role='assistant', + tool_calls=[ + provider_message.ToolCall( + id='call-1', + type='function', + function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'), + ) + ], + is_final=True, + ) + ) + + assert emitted is not None + final_msg = accumulator.final_message() + assert final_msg.tool_calls[0].function.name == 'exec' + assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}' + + @pytest.mark.asyncio async def test_localagent_uses_exec_for_exact_calculation(): provider = RecordingProvider() diff --git a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx index c0899318..382b5a0f 100644 --- a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx +++ b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx @@ -130,32 +130,6 @@ export default function AddModelPopover({ setScanLoading(true); try { const result = await onScanModels(trigger ? undefined : tab); - - const debugData = ( - result.debug?.response as { data?: Record[] } - )?.data; - if (Array.isArray(debugData)) { - const debugMap = new Map>(); - for (const item of debugData) { - if (typeof item?.id === 'string') { - debugMap.set(item.id, item); - } - } - for (const model of result.models) { - const debugItem = debugMap.get(model.id); - if (!debugItem) continue; - const features = debugItem.features as - | Record - | undefined; - const tools = features?.tools as Record | undefined; - if (tools?.function_calling === true) { - const nextAbilities = new Set(model.abilities || []); - nextAbilities.add('func_call'); - model.abilities = [...nextAbilities]; - } - } - } - setScannedModels(result.models); setSelectedScannedModels({}); } finally {