mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-19 20:14:20 +00:00
fix(litellmchat): preserve provider_specific_fields for Gemini thought_signature (#2265)
Update _normalize_stream_tool_calls to preserve provider_specific_fields (including thought_signature) from streaming tool call chunks. Also preserve provider_specific_fields from delta in invoke_llm_stream. This ensures Gemini's thought_signature is round-tripped correctly: 1. LiteLLM extracts thought_signature from Gemini response 2. It's preserved in Message/ToolCall entities (via SDK changes) 3. _convert_messages includes it in the next request Also add unit tests for provider_specific_fields round-tripping. Fixes: langbot-app/LangBot#1899
This commit is contained in:
@@ -363,9 +363,13 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
def _normalize_stream_tool_calls(
|
||||
self,
|
||||
raw_tool_calls: typing.Any,
|
||||
tool_call_state: dict[int, dict[str, str]],
|
||||
tool_call_state: dict[int, dict[str, typing.Any]],
|
||||
) -> list[dict] | None:
|
||||
"""Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
|
||||
"""Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them.
|
||||
|
||||
Also preserves provider_specific_fields (e.g., Gemini thought_signature) for
|
||||
round-tripping to the next request.
|
||||
"""
|
||||
if not raw_tool_calls:
|
||||
return None
|
||||
|
||||
@@ -376,16 +380,38 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if not isinstance(index, int):
|
||||
index = fallback_index
|
||||
|
||||
state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
|
||||
state = tool_call_state.setdefault(
|
||||
index,
|
||||
{
|
||||
'id': '',
|
||||
'type': 'function',
|
||||
'name': '',
|
||||
'provider_specific_fields': None,
|
||||
},
|
||||
)
|
||||
if tool_call.get('id'):
|
||||
state['id'] = tool_call['id']
|
||||
if tool_call.get('type'):
|
||||
state['type'] = tool_call['type']
|
||||
|
||||
# Preserve provider_specific_fields from the raw tool call
|
||||
if 'provider_specific_fields' in tool_call:
|
||||
state['provider_specific_fields'] = tool_call['provider_specific_fields']
|
||||
|
||||
function = self._as_dict(tool_call.get('function'))
|
||||
if function.get('name'):
|
||||
state['name'] = function['name']
|
||||
|
||||
# Also check function-level provider_specific_fields
|
||||
if 'provider_specific_fields' in function:
|
||||
# Merge function-level into tool-level, function-level takes precedence
|
||||
func_psf = function['provider_specific_fields']
|
||||
if state['provider_specific_fields']:
|
||||
merged = {**state['provider_specific_fields'], **func_psf}
|
||||
state['provider_specific_fields'] = merged
|
||||
else:
|
||||
state['provider_specific_fields'] = func_psf
|
||||
|
||||
arguments = function.get('arguments')
|
||||
if arguments is None:
|
||||
arguments = ''
|
||||
@@ -406,16 +432,20 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if not state['id'] or not state['name']:
|
||||
continue
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
'id': state['id'],
|
||||
'type': state['type'] or 'function',
|
||||
'function': {
|
||||
'name': state['name'],
|
||||
'arguments': arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
tool_call_dict: dict[str, typing.Any] = {
|
||||
'id': state['id'],
|
||||
'type': state['type'] or 'function',
|
||||
'function': {
|
||||
'name': state['name'],
|
||||
'arguments': arguments,
|
||||
},
|
||||
}
|
||||
|
||||
# Include provider_specific_fields if present
|
||||
if state['provider_specific_fields']:
|
||||
tool_call_dict['provider_specific_fields'] = state['provider_specific_fields']
|
||||
|
||||
normalized.append(tool_call_dict)
|
||||
|
||||
return normalized or None
|
||||
|
||||
@@ -539,7 +569,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
|
||||
chunk_idx = 0
|
||||
role = 'assistant'
|
||||
tool_call_state: dict[int, dict[str, str]] = {}
|
||||
tool_call_state: dict[int, dict[str, typing.Any]] = {}
|
||||
|
||||
try:
|
||||
response = await acompletion(**args)
|
||||
@@ -589,13 +619,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
chunk_data = {
|
||||
chunk_data: dict[str, typing.Any] = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': tool_calls,
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# Preserve provider_specific_fields from delta (e.g., Gemini thought_signatures)
|
||||
if delta.get('provider_specific_fields'):
|
||||
chunk_data['provider_specific_fields'] = delta['provider_specific_fields']
|
||||
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
yield provider_message.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""Unit tests for provider_specific_fields round-trip in LiteLLMRequester.
|
||||
|
||||
This tests the fix for GitHub issue #1899: Gemini requires thought_signature
|
||||
to be preserved across tool call rounds for function calls to work correctly.
|
||||
"""
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
|
||||
|
||||
|
||||
def _make_requester() -> LiteLLMRequester:
|
||||
# _convert_messages and _normalize_stream_tool_calls do not touch instance config.
|
||||
return LiteLLMRequester.__new__(LiteLLMRequester)
|
||||
|
||||
|
||||
def test_convert_messages_preserves_tool_call_provider_specific_fields():
|
||||
"""Tool calls should retain provider_specific_fields through _convert_messages."""
|
||||
req = _make_requester()
|
||||
msg = provider_message.Message(
|
||||
role='assistant',
|
||||
content=None,
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call_123',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name='search',
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
provider_specific_fields={
|
||||
'thought_signature': 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ==',
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
out = req._convert_messages([msg])
|
||||
assert len(out) == 1
|
||||
assert out[0]['tool_calls'] is not None
|
||||
assert len(out[0]['tool_calls']) == 1
|
||||
|
||||
tc = out[0]['tool_calls'][0]
|
||||
assert tc['id'] == 'call_123'
|
||||
assert tc['function']['name'] == 'search'
|
||||
assert 'provider_specific_fields' in tc
|
||||
assert tc['provider_specific_fields']['thought_signature'] == 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ=='
|
||||
|
||||
|
||||
def test_convert_messages_preserves_message_provider_specific_fields():
|
||||
"""Messages should retain provider_specific_fields through _convert_messages."""
|
||||
req = _make_requester()
|
||||
msg = provider_message.Message(
|
||||
role='assistant',
|
||||
content='Hello',
|
||||
provider_specific_fields={
|
||||
'thought_signatures': ['sig1', 'sig2'],
|
||||
},
|
||||
)
|
||||
out = req._convert_messages([msg])
|
||||
assert len(out) == 1
|
||||
assert 'provider_specific_fields' in out[0]
|
||||
assert out[0]['provider_specific_fields']['thought_signatures'] == ['sig1', 'sig2']
|
||||
|
||||
|
||||
def test_normalize_stream_tool_calls_preserves_provider_specific_fields():
|
||||
"""Streaming tool calls should retain provider_specific_fields."""
|
||||
req = _make_requester()
|
||||
tool_call_state: dict[int, dict] = {}
|
||||
|
||||
# Simulate first chunk with id and type
|
||||
raw_tool_calls_1 = [
|
||||
{
|
||||
'index': 0,
|
||||
'id': 'call_abc',
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'get_weather',
|
||||
'arguments': '',
|
||||
},
|
||||
'provider_specific_fields': {
|
||||
'thought_signature': 'dGVzdF9zaWduYXR1cmU=',
|
||||
},
|
||||
},
|
||||
]
|
||||
result_1 = req._normalize_stream_tool_calls(raw_tool_calls_1, tool_call_state)
|
||||
assert result_1 is not None
|
||||
assert len(result_1) == 1
|
||||
assert result_1[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='
|
||||
|
||||
# Simulate second chunk without provider_specific_fields (should be retained from state)
|
||||
raw_tool_calls_2 = [
|
||||
{
|
||||
'index': 0,
|
||||
'function': {
|
||||
'arguments': '{"city": "Tokyo"}',
|
||||
},
|
||||
},
|
||||
]
|
||||
result_2 = req._normalize_stream_tool_calls(raw_tool_calls_2, tool_call_state)
|
||||
assert result_2 is not None
|
||||
assert len(result_2) == 1
|
||||
# Should retain the provider_specific_fields from the first chunk
|
||||
assert result_2[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='
|
||||
assert result_2[0]['function']['arguments'] == '{"city": "Tokyo"}'
|
||||
|
||||
|
||||
def test_normalize_stream_tool_calls_merges_function_level_psf():
|
||||
"""Function-level provider_specific_fields should be merged into tool-level."""
|
||||
req = _make_requester()
|
||||
tool_call_state: dict[int, dict] = {}
|
||||
|
||||
raw_tool_calls = [
|
||||
{
|
||||
'index': 0,
|
||||
'id': 'call_xyz',
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'search',
|
||||
'arguments': '{}',
|
||||
'provider_specific_fields': {
|
||||
'thought_signature': 'ZnVuY19sZXZlbF9zaWc=',
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
result = req._normalize_stream_tool_calls(raw_tool_calls, tool_call_state)
|
||||
assert result is not None
|
||||
assert result[0]['provider_specific_fields']['thought_signature'] == 'ZnVuY19sZXZlbF9zaWc='
|
||||
|
||||
|
||||
def test_tool_call_roundtrip_through_message_entity():
|
||||
"""Full round-trip: LiteLLM response dict -> Message entity -> _convert_messages."""
|
||||
# Simulate what LiteLLM returns for a Gemini tool call response
|
||||
message_data = {
|
||||
'role': 'assistant',
|
||||
'content': None,
|
||||
'tool_calls': [
|
||||
{
|
||||
'id': 'call_gemini_123',
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'web_search',
|
||||
'arguments': '{"query": "test"}',
|
||||
},
|
||||
'provider_specific_fields': {
|
||||
'thought_signature': 'Z2VtaW5pX3NpZ25hdHVyZQ==',
|
||||
},
|
||||
},
|
||||
],
|
||||
'provider_specific_fields': {
|
||||
'thought_signatures': ['Z2VtaW5pX3NpZ25hdHVyZQ=='],
|
||||
},
|
||||
}
|
||||
|
||||
# Parse into Message entity (this is what invoke_llm does)
|
||||
msg = provider_message.Message(**message_data)
|
||||
|
||||
# Verify the entity has the fields
|
||||
assert msg.tool_calls is not None
|
||||
assert len(msg.tool_calls) == 1
|
||||
assert msg.tool_calls[0].provider_specific_fields is not None
|
||||
assert msg.tool_calls[0].provider_specific_fields['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
|
||||
assert msg.provider_specific_fields is not None
|
||||
assert msg.provider_specific_fields['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']
|
||||
|
||||
# Convert back to dict for LiteLLM (this is what _convert_messages does)
|
||||
req = _make_requester()
|
||||
out = req._convert_messages([msg])
|
||||
|
||||
# Verify the fields are preserved in the output
|
||||
assert out[0]['tool_calls'][0]['provider_specific_fields']['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
|
||||
assert out[0]['provider_specific_fields']['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']
|
||||
Reference in New Issue
Block a user