fix(litellmchat): preserve provider_specific_fields for Gemini thought_signature (#2265)

Update _normalize_stream_tool_calls to preserve provider_specific_fields
(including thought_signature) from streaming tool call chunks. Also preserve
provider_specific_fields from delta in invoke_llm_stream.

This ensures Gemini's thought_signature is round-tripped correctly:
1. LiteLLM extracts thought_signature from Gemini response
2. It's preserved in Message/ToolCall entities (via SDK changes)
3. _convert_messages includes it in the next request

Also add unit tests for provider_specific_fields round-tripping.

Fixes: langbot-app/LangBot#1899
This commit is contained in:
huanghuoguoguo
2026-06-19 15:26:12 +00:00
committed by GitHub
parent 492827ea75
commit acfac42107
2 changed files with 221 additions and 15 deletions
@@ -363,9 +363,13 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
def _normalize_stream_tool_calls(
self,
raw_tool_calls: typing.Any,
tool_call_state: dict[int, dict[str, str]],
tool_call_state: dict[int, dict[str, typing.Any]],
) -> list[dict] | None:
"""Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
"""Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them.
Also preserves provider_specific_fields (e.g., Gemini thought_signature) for
round-tripping to the next request.
"""
if not raw_tool_calls:
return None
@@ -376,16 +380,38 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
if not isinstance(index, int):
index = fallback_index
state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
state = tool_call_state.setdefault(
index,
{
'id': '',
'type': 'function',
'name': '',
'provider_specific_fields': None,
},
)
if tool_call.get('id'):
state['id'] = tool_call['id']
if tool_call.get('type'):
state['type'] = tool_call['type']
# Preserve provider_specific_fields from the raw tool call
if 'provider_specific_fields' in tool_call:
state['provider_specific_fields'] = tool_call['provider_specific_fields']
function = self._as_dict(tool_call.get('function'))
if function.get('name'):
state['name'] = function['name']
# Also check function-level provider_specific_fields
if 'provider_specific_fields' in function:
# Merge function-level into tool-level, function-level takes precedence
func_psf = function['provider_specific_fields']
if state['provider_specific_fields']:
merged = {**state['provider_specific_fields'], **func_psf}
state['provider_specific_fields'] = merged
else:
state['provider_specific_fields'] = func_psf
arguments = function.get('arguments')
if arguments is None:
arguments = ''
@@ -406,16 +432,20 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
if not state['id'] or not state['name']:
continue
normalized.append(
{
'id': state['id'],
'type': state['type'] or 'function',
'function': {
'name': state['name'],
'arguments': arguments,
},
}
)
tool_call_dict: dict[str, typing.Any] = {
'id': state['id'],
'type': state['type'] or 'function',
'function': {
'name': state['name'],
'arguments': arguments,
},
}
# Include provider_specific_fields if present
if state['provider_specific_fields']:
tool_call_dict['provider_specific_fields'] = state['provider_specific_fields']
normalized.append(tool_call_dict)
return normalized or None
@@ -539,7 +569,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
chunk_idx = 0
role = 'assistant'
tool_call_state: dict[int, dict[str, str]] = {}
tool_call_state: dict[int, dict[str, typing.Any]] = {}
try:
response = await acompletion(**args)
@@ -589,13 +619,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
chunk_idx += 1
continue
chunk_data = {
chunk_data: dict[str, typing.Any] = {
'role': role,
'content': delta_content if delta_content else None,
'tool_calls': tool_calls,
'is_final': bool(finish_reason),
}
# Preserve provider_specific_fields from delta (e.g., Gemini thought_signatures)
if delta.get('provider_specific_fields'):
chunk_data['provider_specific_fields'] = delta['provider_specific_fields']
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
yield provider_message.MessageChunk(**chunk_data)
chunk_idx += 1