mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-17 03:04:20 +00:00
refactor(provider): simplify litellm capabilities
This commit is contained in:
@@ -68,6 +68,12 @@ class TestBuildLiteLLMModelName:
|
||||
result = requester._build_litellm_model_name('gpt-4o')
|
||||
assert result == 'openai/gpt-4o'
|
||||
|
||||
def test_avoid_duplicate_provider_prefix(self):
|
||||
"""Test model name with an existing matching provider prefix."""
|
||||
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||
result = requester._build_litellm_model_name('openai/gpt-4o')
|
||||
assert result == 'openai/gpt-4o'
|
||||
|
||||
def test_override_provider(self):
|
||||
"""Test override provider via parameter"""
|
||||
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||
@@ -151,7 +157,7 @@ class TestInvokeLLMStreamUsage:
|
||||
calls record 0 tokens.
|
||||
"""
|
||||
|
||||
def _make_chunk(self, *, content=None, finish_reason=None, usage=None, has_choice=True):
|
||||
def _make_chunk(self, *, content=None, tool_calls=None, finish_reason=None, usage=None, has_choice=True):
|
||||
chunk = Mock()
|
||||
if usage is not None:
|
||||
chunk.usage = usage
|
||||
@@ -161,7 +167,7 @@ class TestInvokeLLMStreamUsage:
|
||||
choice = Mock()
|
||||
delta = Mock()
|
||||
delta.model_dump = Mock(
|
||||
return_value={'role': 'assistant', 'content': content, 'tool_calls': None}
|
||||
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
|
||||
)
|
||||
choice.delta = delta
|
||||
choice.finish_reason = finish_reason
|
||||
@@ -250,6 +256,78 @@ class TestInvokeLLMStreamUsage:
|
||||
|
||||
assert query.variables['_stream_usage']['total_tokens'] == 12
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_tool_call_delta_missing_id_and_name(self):
|
||||
"""LiteLLM may stream tool-call argument deltas with id/name set to None."""
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
mock_ap = Mock()
|
||||
mock_ap.tool_mgr = Mock()
|
||||
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
|
||||
return_value=[{'type': 'function', 'function': {'name': 'qa_plugin_echo'}}]
|
||||
)
|
||||
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||
|
||||
chunks = [
|
||||
self._make_chunk(
|
||||
tool_calls=[
|
||||
{
|
||||
'index': 0,
|
||||
'id': 'call_123',
|
||||
'type': 'function',
|
||||
'function': {'name': 'qa_plugin_echo', 'arguments': ''},
|
||||
}
|
||||
]
|
||||
),
|
||||
self._make_chunk(
|
||||
tool_calls=[
|
||||
{
|
||||
'index': 0,
|
||||
'id': None,
|
||||
'type': None,
|
||||
'function': {'name': None, 'arguments': '{"text":'},
|
||||
}
|
||||
]
|
||||
),
|
||||
self._make_chunk(
|
||||
tool_calls=[
|
||||
{
|
||||
'index': 0,
|
||||
'function': {'arguments': '"plugin-tool-ok"}'},
|
||||
}
|
||||
]
|
||||
),
|
||||
self._make_chunk(finish_reason='tool_calls'),
|
||||
]
|
||||
|
||||
async def _aiter(*args, **kwargs):
|
||||
for c in chunks:
|
||||
yield c
|
||||
|
||||
query = Mock(spec=pipeline_query.Query)
|
||||
query.variables = {}
|
||||
messages = [provider_message.Message(role='user', content='Call the tool')]
|
||||
funcs = [Mock()]
|
||||
|
||||
with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
|
||||
collected = [
|
||||
chunk async for chunk in requester.invoke_llm_stream(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
funcs=funcs,
|
||||
)
|
||||
]
|
||||
|
||||
tool_chunks = [chunk for chunk in collected if chunk.tool_calls]
|
||||
assert len(tool_chunks) == 3
|
||||
assert tool_chunks[1].tool_calls[0].id == 'call_123'
|
||||
assert tool_chunks[1].tool_calls[0].function.name == 'qa_plugin_echo'
|
||||
assert tool_chunks[1].tool_calls[0].function.arguments == '{"text":'
|
||||
assert tool_chunks[2].tool_calls[0].function.arguments == '"plugin-tool-ok"}'
|
||||
|
||||
|
||||
class TestProcessThinkingContent:
|
||||
"""Test _process_thinking_content method"""
|
||||
@@ -499,6 +577,32 @@ class TestInvokeLLM:
|
||||
)
|
||||
|
||||
assert result_msg.tool_calls is not None
|
||||
called_kwargs = litellmchat.acompletion.await_args.kwargs
|
||||
assert called_kwargs['tools'] == [{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||
assert called_kwargs['tool_choice'] == 'auto'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_completion_args_preserves_explicit_tool_choice(self):
|
||||
"""Model extra args can override the default auto tool choice."""
|
||||
mock_ap = Mock()
|
||||
mock_ap.tool_mgr = Mock()
|
||||
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
|
||||
return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||
)
|
||||
|
||||
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||
model.model_entity.extra_args = {'tool_choice': 'required'}
|
||||
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
funcs = [Mock(spec=resource_tool.LLMTool)]
|
||||
messages = [provider_message.Message(role='user', content='What is the weather?')]
|
||||
|
||||
args = await requester._build_completion_args(model, messages, funcs)
|
||||
|
||||
assert args['tool_choice'] == 'required'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_error_handling(self):
|
||||
@@ -754,6 +858,44 @@ class TestScanModels:
|
||||
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
|
||||
assert len(embedding_models) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_enriches_llm_abilities_and_context_length(self):
|
||||
"""Scanned LLM models get LiteLLM-derived abilities and context length."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'timeout': 60,
|
||||
},
|
||||
)
|
||||
requester._supports_function_calling = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
|
||||
requester._supports_vision = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
|
||||
requester._safe_context_length = Mock(side_effect=lambda model_id: 128000 if model_id == 'gpt-4o' else None)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json = Mock(
|
||||
return_value={
|
||||
'data': [
|
||||
{'id': 'gpt-4o'},
|
||||
{'id': 'text-embedding-3-small'},
|
||||
{'id': 'bge-reranker-v2'},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await requester.scan_models(api_key='test-key')
|
||||
|
||||
by_id = {model['id']: model for model in result['models']}
|
||||
assert by_id['gpt-4o']['abilities'] == ['func_call', 'vision']
|
||||
assert by_id['gpt-4o']['context_length'] == 128000
|
||||
assert by_id['text-embedding-3-small']['type'] == 'embedding'
|
||||
assert by_id['bge-reranker-v2']['type'] == 'rerank'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_no_base_url(self):
|
||||
"""Test scan_models without base_url raises error"""
|
||||
|
||||
@@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator
|
||||
|
||||
|
||||
class RecordingProvider:
|
||||
@@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query:
|
||||
)
|
||||
|
||||
|
||||
def test_stream_accumulator_merges_fragmented_tool_call_arguments():
|
||||
accumulator = _StreamAccumulator(msg_sequence=1)
|
||||
|
||||
assert (
|
||||
accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='{"command":'),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
emitted = accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'),
|
||||
)
|
||||
],
|
||||
is_final=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert emitted is not None
|
||||
final_msg = accumulator.final_message()
|
||||
assert final_msg.tool_calls[0].function.name == 'exec'
|
||||
assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_uses_exec_for_exact_calculation():
|
||||
provider = RecordingProvider()
|
||||
|
||||
Reference in New Issue
Block a user