diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
index a6c09b7e7..c1b5ae0b6 100644
--- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
+++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
@@ -363,9 +363,13 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
     def _normalize_stream_tool_calls(
         self,
         raw_tool_calls: typing.Any,
-        tool_call_state: dict[int, dict[str, str]],
+        tool_call_state: dict[int, dict[str, typing.Any]],
     ) -> list[dict] | None:
-        """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
+        """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them.
+
+        Also preserves provider_specific_fields (e.g., Gemini thought_signature) for
+        round-tripping to the next request.
+        """
         if not raw_tool_calls:
             return None
 
@@ -376,16 +380,38 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
             if not isinstance(index, int):
                 index = fallback_index
 
-            state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
+            state = tool_call_state.setdefault(
+                index,
+                {
+                    'id': '',
+                    'type': 'function',
+                    'name': '',
+                    'provider_specific_fields': None,
+                },
+            )
             if tool_call.get('id'):
                 state['id'] = tool_call['id']
             if tool_call.get('type'):
                 state['type'] = tool_call['type']
 
+            # Preserve provider_specific_fields from the raw tool call. A later delta
+            # may carry the key with a None/empty value; that must never erase fields
+            # already captured from an earlier chunk, so only truthy values are stored.
+            tool_psf = tool_call.get('provider_specific_fields')
+            if tool_psf:
+                state['provider_specific_fields'] = tool_psf
+
             function = self._as_dict(tool_call.get('function'))
             if function.get('name'):
                 state['name'] = function['name']
 
+            # Function-level provider_specific_fields are merged into the tool-level
+            # ones (function-level keys win); None/empty values are ignored.
+            func_psf = function.get('provider_specific_fields')
+            if func_psf:
+                existing_psf = state['provider_specific_fields'] or {}
+                state['provider_specific_fields'] = {**existing_psf, **func_psf}
+
             arguments = function.get('arguments')
             if arguments is None:
                 arguments = ''
@@ -406,16 +432,20 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
             if not state['id'] or not state['name']:
                 continue
 
-            normalized.append(
-                {
-                    'id': state['id'],
-                    'type': state['type'] or 'function',
-                    'function': {
-                        'name': state['name'],
-                        'arguments': arguments,
-                    },
-                }
-            )
+            tool_call_dict: dict[str, typing.Any] = {
+                'id': state['id'],
+                'type': state['type'] or 'function',
+                'function': {
+                    'name': state['name'],
+                    'arguments': arguments,
+                },
+            }
+
+            # Include provider_specific_fields if present
+            if state['provider_specific_fields']:
+                tool_call_dict['provider_specific_fields'] = state['provider_specific_fields']
+
+            normalized.append(tool_call_dict)
 
         return normalized or None
 
@@ -539,7 +569,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
 
         chunk_idx = 0
         role = 'assistant'
-        tool_call_state: dict[int, dict[str, str]] = {}
+        tool_call_state: dict[int, dict[str, typing.Any]] = {}
 
         try:
             response = await acompletion(**args)
@@ -589,13 +619,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
                     chunk_idx += 1
                     continue
 
-                chunk_data = {
+                chunk_data: dict[str, typing.Any] = {
                     'role': role,
                     'content': delta_content if delta_content else None,
                     'tool_calls': tool_calls,
                     'is_final': bool(finish_reason),
                 }
 
+                # Preserve provider_specific_fields from delta (e.g., Gemini thought_signatures)
+                if delta.get('provider_specific_fields'):
+                    chunk_data['provider_specific_fields'] = delta['provider_specific_fields']
+
                 chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
                 yield provider_message.MessageChunk(**chunk_data)
                 chunk_idx += 1
diff --git a/tests/unit_tests/provider/test_provider_specific_fields.py b/tests/unit_tests/provider/test_provider_specific_fields.py
new file mode 100644
index 000000000..2520e608e
--- /dev/null
+++ b/tests/unit_tests/provider/test_provider_specific_fields.py
@@ -0,0 +1,203 @@
+"""Unit tests for provider_specific_fields round-trip in LiteLLMRequester.
+
+This tests the fix for GitHub issue #1899: Gemini requires thought_signature
+to be preserved across tool call rounds for function calls to work correctly.
+"""
+
+import langbot_plugin.api.entities.builtin.provider.message as provider_message
+
+from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
+
+
+def _make_requester() -> LiteLLMRequester:
+    # _convert_messages and _normalize_stream_tool_calls do not touch instance config.
+    return LiteLLMRequester.__new__(LiteLLMRequester)
+
+
+def test_convert_messages_preserves_tool_call_provider_specific_fields():
+    """Tool calls should retain provider_specific_fields through _convert_messages."""
+    req = _make_requester()
+    msg = provider_message.Message(
+        role='assistant',
+        content=None,
+        tool_calls=[
+            provider_message.ToolCall(
+                id='call_123',
+                type='function',
+                function=provider_message.FunctionCall(
+                    name='search',
+                    arguments='{"query": "test"}',
+                ),
+                provider_specific_fields={
+                    'thought_signature': 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ==',
+                },
+            ),
+        ],
+    )
+    out = req._convert_messages([msg])
+    assert len(out) == 1
+    assert out[0]['tool_calls'] is not None
+    assert len(out[0]['tool_calls']) == 1
+
+    tc = out[0]['tool_calls'][0]
+    assert tc['id'] == 'call_123'
+    assert tc['function']['name'] == 'search'
+    assert 'provider_specific_fields' in tc
+    assert tc['provider_specific_fields']['thought_signature'] == 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ=='
+
+
+def test_convert_messages_preserves_message_provider_specific_fields():
+    """Messages should retain provider_specific_fields through _convert_messages."""
+    req = _make_requester()
+    msg = provider_message.Message(
+        role='assistant',
+        content='Hello',
+        provider_specific_fields={
+            'thought_signatures': ['sig1', 'sig2'],
+        },
+    )
+    out = req._convert_messages([msg])
+    assert len(out) == 1
+    assert 'provider_specific_fields' in out[0]
+    assert out[0]['provider_specific_fields']['thought_signatures'] == ['sig1', 'sig2']
+
+
+def test_normalize_stream_tool_calls_preserves_provider_specific_fields():
+    """Streaming tool calls should retain provider_specific_fields."""
+    req = _make_requester()
+    tool_call_state: dict[int, dict] = {}
+
+    # Simulate first chunk with id and type
+    raw_tool_calls_1 = [
+        {
+            'index': 0,
+            'id': 'call_abc',
+            'type': 'function',
+            'function': {
+                'name': 'get_weather',
+                'arguments': '',
+            },
+            'provider_specific_fields': {
+                'thought_signature': 'dGVzdF9zaWduYXR1cmU=',
+            },
+        },
+    ]
+    result_1 = req._normalize_stream_tool_calls(raw_tool_calls_1, tool_call_state)
+    assert result_1 is not None
+    assert len(result_1) == 1
+    assert result_1[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='
+
+    # Simulate second chunk without provider_specific_fields (should be retained from state)
+    raw_tool_calls_2 = [
+        {
+            'index': 0,
+            'function': {
+                'arguments': '{"city": "Tokyo"}',
+            },
+        },
+    ]
+    result_2 = req._normalize_stream_tool_calls(raw_tool_calls_2, tool_call_state)
+    assert result_2 is not None
+    assert len(result_2) == 1
+    # Should retain the provider_specific_fields from the first chunk
+    assert result_2[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='
+    assert result_2[0]['function']['arguments'] == '{"city": "Tokyo"}'
+
+
+def test_normalize_stream_tool_calls_merges_function_level_psf():
+    """Function-level provider_specific_fields should be merged into tool-level."""
+    req = _make_requester()
+    tool_call_state: dict[int, dict] = {}
+
+    raw_tool_calls = [
+        {
+            'index': 0,
+            'id': 'call_xyz',
+            'type': 'function',
+            'function': {
+                'name': 'search',
+                'arguments': '{}',
+                'provider_specific_fields': {
+                    'thought_signature': 'ZnVuY19sZXZlbF9zaWc=',
+                },
+            },
+        },
+    ]
+    result = req._normalize_stream_tool_calls(raw_tool_calls, tool_call_state)
+    assert result is not None
+    assert result[0]['provider_specific_fields']['thought_signature'] == 'ZnVuY19sZXZlbF9zaWc='
+
+
+def test_tool_call_roundtrip_through_message_entity():
+    """Full round-trip: LiteLLM response dict -> Message entity -> _convert_messages."""
+    # Simulate what LiteLLM returns for a Gemini tool call response
+    message_data = {
+        'role': 'assistant',
+        'content': None,
+        'tool_calls': [
+            {
+                'id': 'call_gemini_123',
+                'type': 'function',
+                'function': {
+                    'name': 'web_search',
+                    'arguments': '{"query": "test"}',
+                },
+                'provider_specific_fields': {
+                    'thought_signature': 'Z2VtaW5pX3NpZ25hdHVyZQ==',
+                },
+            },
+        ],
+        'provider_specific_fields': {
+            'thought_signatures': ['Z2VtaW5pX3NpZ25hdHVyZQ=='],
+        },
+    }
+
+    # Parse into Message entity (this is what invoke_llm does)
+    msg = provider_message.Message(**message_data)
+
+    # Verify the entity has the fields
+    assert msg.tool_calls is not None
+    assert len(msg.tool_calls) == 1
+    assert msg.tool_calls[0].provider_specific_fields is not None
+    assert msg.tool_calls[0].provider_specific_fields['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
+    assert msg.provider_specific_fields is not None
+    assert msg.provider_specific_fields['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']
+
+    # Convert back to dict for LiteLLM (this is what _convert_messages does)
+    req = _make_requester()
+    out = req._convert_messages([msg])
+
+    # Verify the fields are preserved in the output
+    assert out[0]['tool_calls'][0]['provider_specific_fields']['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
+    assert out[0]['provider_specific_fields']['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']
+
+
+def test_normalize_stream_tool_calls_keeps_psf_when_later_delta_is_none():
+    """A later delta carrying provider_specific_fields=None must not erase saved fields."""
+    req = _make_requester()
+    tool_call_state: dict[int, dict] = {}
+
+    req._normalize_stream_tool_calls(
+        [
+            {
+                'index': 0,
+                'id': 'call_abc',
+                'type': 'function',
+                'function': {'name': 'get_weather', 'arguments': ''},
+                'provider_specific_fields': {'thought_signature': 'dGVzdF9zaWduYXR1cmU='},
+            },
+        ],
+        tool_call_state,
+    )
+    result = req._normalize_stream_tool_calls(
+        [
+            {
+                'index': 0,
+                'function': {'arguments': '{}', 'provider_specific_fields': None},
+                'provider_specific_fields': None,
+            },
+        ],
+        tool_call_state,
+    )
+    assert result is not None
+    assert result[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='