mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-26 23:44:19 +00:00
257 lines
9.1 KiB
Python
257 lines
9.1 KiB
Python
"""Unit tests for LiteLLMRequester message/tool conversion.
|
|
|
|
This includes provider_specific_fields round-trip coverage for GitHub issue
#1899 and token counting preflight behavior for AgentRunner context budgeting.
|
|
"""
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
|
|
|
from langbot.pkg.provider.modelmgr import requester as model_requester
|
|
from langbot.pkg.provider.modelmgr.requesters.litellmchat import LiteLLMRequester
|
|
|
|
|
|
def _make_requester() -> LiteLLMRequester:
    """Return a bare ``LiteLLMRequester`` built without running ``__init__``.

    The instance is allocated with ``__new__`` only, so no attributes such as
    ``requester_cfg`` or ``ap`` are set on it. Use it only for methods that do
    not read instance state.
    """
    # _convert_messages and _normalize_stream_tool_calls do not touch instance config.
    return LiteLLMRequester.__new__(LiteLLMRequester)
|
|
|
|
|
|
def _make_configured_requester() -> LiteLLMRequester:
    """Return a ``LiteLLMRequester`` with stub config and a mocked tool manager.

    ``__init__`` is bypassed via ``__new__``; the two attributes set here are
    the only instance state this helper provides:

    * ``requester_cfg`` -- a plain dict of requester settings
      (``custom_llm_provider`` is ``'openai'``).
    * ``ap`` -- a namespace whose ``tool_mgr.generate_tools_for_openai`` is an
      ``AsyncMock`` resolving to a single OpenAI-style ``search`` function tool.
    """
    req = LiteLLMRequester.__new__(LiteLLMRequester)
    req.requester_cfg = {
        'base_url': '',
        'timeout': 120,
        'custom_llm_provider': 'openai',
        'drop_params': False,
        'num_retries': 0,
        'api_version': '',
    }
    # Only the tool manager is stubbed; any other attribute access on ``ap``
    # raises AttributeError, which surfaces unexpected dependencies in tests.
    req.ap = SimpleNamespace(
        tool_mgr=SimpleNamespace(
            generate_tools_for_openai=AsyncMock(
                return_value=[
                    {
                        'type': 'function',
                        'function': {
                            'name': 'search',
                            'description': 'Search',
                            'parameters': {'type': 'object'},
                        },
                    }
                ]
            )
        )
    )
    return req
|
|
|
|
|
|
def _make_runtime_model() -> model_requester.RuntimeLLMModel:
    """Return a duck-typed stand-in for a ``RuntimeLLMModel``.

    The returned object is a ``SimpleNamespace`` rather than a real
    ``RuntimeLLMModel``; it exposes only:

    * ``model_entity.name`` (``'gpt-4.1'``) and ``model_entity.extra_args``
      (``{'temperature': 0.2}``);
    * ``provider.token_mgr.get_token``, a ``Mock`` returning ``'sk-test'``.
    """
    provider = SimpleNamespace(token_mgr=SimpleNamespace(get_token=Mock(return_value='sk-test')))
    return SimpleNamespace(
        model_entity=SimpleNamespace(
            name='gpt-4.1',
            extra_args={'temperature': 0.2},
        ),
        provider=provider,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_count_tokens_uses_litellm_counter_with_request_messages_and_tools():
    """Token preflight uses the same LiteLLM request shape as chat completion."""
    req = _make_configured_requester()
    model = _make_runtime_model()
    tool = resource_tool.LLMTool(
        name='search',
        human_desc='Search',
        description='Search',
        parameters={'type': 'object'},
        func=lambda **kwargs: None,
    )

    # Replace LiteLLM's token counter so the test only inspects the arguments
    # that count_tokens forwards to it.
    with patch('langbot.pkg.provider.modelmgr.requesters.litellmchat.litellm.token_counter', return_value=42) as counter:
        tokens = await req.count_tokens(
            model=model,
            messages=[
                provider_message.Message(
                    role='user',
                    content=[
                        provider_message.ContentElement(type='text', text='hello'),
                        provider_message.ContentElement(type='file_url', file_url='https://example.test/a.pdf'),
                    ],
                )
            ],
            funcs=[tool],
            extra_args={'presence_penalty': 0.1},
        )

    # The mocked counter's return value is passed straight through.
    assert tokens == 42
    counter.assert_called_once()
    kwargs = counter.call_args.kwargs
    assert kwargs['model'] == 'openai/gpt-4.1'
    # Only the text element is expected in the forwarded message; the file_url
    # element from the input must not appear.
    assert kwargs['messages'] == [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]}]
    assert kwargs['tools'][0]['function']['name'] == 'search'
    assert kwargs['tool_choice'] == 'auto'
|
|
|
|
|
|
def test_convert_messages_preserves_tool_call_provider_specific_fields():
    """Tool calls should retain provider_specific_fields through _convert_messages."""
    req = _make_requester()
    # Assistant message with no text content and one tool call that carries a
    # provider-specific ``thought_signature``.
    msg = provider_message.Message(
        role='assistant',
        content=None,
        tool_calls=[
            provider_message.ToolCall(
                id='call_123',
                type='function',
                function=provider_message.FunctionCall(
                    name='search',
                    arguments='{"query": "test"}',
                ),
                provider_specific_fields={
                    'thought_signature': 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ==',
                },
            ),
        ],
    )
    out = req._convert_messages([msg])
    assert len(out) == 1
    assert out[0]['tool_calls'] is not None
    assert len(out[0]['tool_calls']) == 1

    tc = out[0]['tool_calls'][0]
    assert tc['id'] == 'call_123'
    assert tc['function']['name'] == 'search'
    assert 'provider_specific_fields' in tc
    assert tc['provider_specific_fields']['thought_signature'] == 'c2tpcF90aG91Z2h0X3NpZ25hdHVyZQ=='
|
|
|
|
|
|
def test_convert_messages_preserves_message_provider_specific_fields():
    """Messages should retain provider_specific_fields through _convert_messages."""
    req = _make_requester()
    # Message-level (not tool-call-level) provider_specific_fields.
    msg = provider_message.Message(
        role='assistant',
        content='Hello',
        provider_specific_fields={
            'thought_signatures': ['sig1', 'sig2'],
        },
    )
    out = req._convert_messages([msg])
    assert len(out) == 1
    assert 'provider_specific_fields' in out[0]
    assert out[0]['provider_specific_fields']['thought_signatures'] == ['sig1', 'sig2']
|
|
|
|
|
|
def test_normalize_stream_tool_calls_preserves_provider_specific_fields():
    """Streaming tool calls should retain provider_specific_fields."""
    req = _make_requester()
    # The same state dict is passed to both calls so that data from the first
    # chunk is available when the second chunk is normalized.
    tool_call_state: dict[int, dict] = {}

    # Simulate first chunk with id and type
    raw_tool_calls_1 = [
        {
            'index': 0,
            'id': 'call_abc',
            'type': 'function',
            'function': {
                'name': 'get_weather',
                'arguments': '',
            },
            'provider_specific_fields': {
                'thought_signature': 'dGVzdF9zaWduYXR1cmU=',
            },
        },
    ]
    result_1 = req._normalize_stream_tool_calls(raw_tool_calls_1, tool_call_state)
    assert result_1 is not None
    assert len(result_1) == 1
    assert result_1[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='

    # Simulate second chunk without provider_specific_fields (should be retained from state)
    raw_tool_calls_2 = [
        {
            'index': 0,
            'function': {
                'arguments': '{"city": "Tokyo"}',
            },
        },
    ]
    result_2 = req._normalize_stream_tool_calls(raw_tool_calls_2, tool_call_state)
    assert result_2 is not None
    assert len(result_2) == 1
    # Should retain the provider_specific_fields from the first chunk
    assert result_2[0]['provider_specific_fields']['thought_signature'] == 'dGVzdF9zaWduYXR1cmU='
    assert result_2[0]['function']['arguments'] == '{"city": "Tokyo"}'
|
|
|
|
|
|
def test_normalize_stream_tool_calls_merges_function_level_psf():
    """Function-level provider_specific_fields should be merged into tool-level."""
    req = _make_requester()
    tool_call_state: dict[int, dict] = {}

    # provider_specific_fields is nested under 'function' here, with none at
    # the tool-call level; the result must still expose it at the tool level.
    raw_tool_calls = [
        {
            'index': 0,
            'id': 'call_xyz',
            'type': 'function',
            'function': {
                'name': 'search',
                'arguments': '{}',
                'provider_specific_fields': {
                    'thought_signature': 'ZnVuY19sZXZlbF9zaWc=',
                },
            },
        },
    ]
    result = req._normalize_stream_tool_calls(raw_tool_calls, tool_call_state)
    assert result is not None
    assert result[0]['provider_specific_fields']['thought_signature'] == 'ZnVuY19sZXZlbF9zaWc='
|
|
|
|
|
|
def test_tool_call_roundtrip_through_message_entity():
    """Full round-trip: LiteLLM response dict -> Message entity -> _convert_messages."""
    # Simulate what LiteLLM returns for a Gemini tool call response
    message_data = {
        'role': 'assistant',
        'content': None,
        'tool_calls': [
            {
                'id': 'call_gemini_123',
                'type': 'function',
                'function': {
                    'name': 'web_search',
                    'arguments': '{"query": "test"}',
                },
                'provider_specific_fields': {
                    'thought_signature': 'Z2VtaW5pX3NpZ25hdHVyZQ==',
                },
            },
        ],
        'provider_specific_fields': {
            'thought_signatures': ['Z2VtaW5pX3NpZ25hdHVyZQ=='],
        },
    }

    # Parse into Message entity (this is what invoke_llm does)
    msg = provider_message.Message(**message_data)

    # Verify the entity has the fields
    assert msg.tool_calls is not None
    assert len(msg.tool_calls) == 1
    assert msg.tool_calls[0].provider_specific_fields is not None
    assert msg.tool_calls[0].provider_specific_fields['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
    assert msg.provider_specific_fields is not None
    assert msg.provider_specific_fields['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']

    # Convert back to dict for LiteLLM (this is what _convert_messages does)
    req = _make_requester()
    out = req._convert_messages([msg])

    # Verify the fields are preserved in the output
    assert out[0]['tool_calls'][0]['provider_specific_fields']['thought_signature'] == 'Z2VtaW5pX3NpZ25hdHVyZQ=='
    assert out[0]['provider_specific_fields']['thought_signatures'] == ['Z2VtaW5pX3NpZ25hdHVyZQ==']
|