refactor(provider): simplify litellm capabilities

This commit is contained in:
huanghuoguoguo
2026-06-06 00:21:19 +08:00
parent 39673444d2
commit 7fb3cfa638
7 changed files with 443 additions and 141 deletions
+144 -2
View File
@@ -68,6 +68,12 @@ class TestBuildLiteLLMModelName:
result = requester._build_litellm_model_name('gpt-4o')
assert result == 'openai/gpt-4o'
def test_avoid_duplicate_provider_prefix(self):
"""Test model name with an existing matching provider prefix."""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
result = requester._build_litellm_model_name('openai/gpt-4o')
assert result == 'openai/gpt-4o'
def test_override_provider(self):
"""Test override provider via parameter"""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
@@ -151,7 +157,7 @@ class TestInvokeLLMStreamUsage:
calls record 0 tokens.
"""
def _make_chunk(self, *, content=None, finish_reason=None, usage=None, has_choice=True):
def _make_chunk(self, *, content=None, tool_calls=None, finish_reason=None, usage=None, has_choice=True):
chunk = Mock()
if usage is not None:
chunk.usage = usage
@@ -161,7 +167,7 @@ class TestInvokeLLMStreamUsage:
choice = Mock()
delta = Mock()
delta.model_dump = Mock(
return_value={'role': 'assistant', 'content': content, 'tool_calls': None}
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
)
choice.delta = delta
choice.finish_reason = finish_reason
@@ -250,6 +256,78 @@ class TestInvokeLLMStreamUsage:
assert query.variables['_stream_usage']['total_tokens'] == 12
@pytest.mark.asyncio
async def test_stream_tool_call_delta_missing_id_and_name(self):
"""LiteLLM may stream tool-call argument deltas with id/name set to None."""
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
mock_ap = Mock()
mock_ap.tool_mgr = Mock()
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
return_value=[{'type': 'function', 'function': {'name': 'qa_plugin_echo'}}]
)
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
model = MockRuntimeModel('gpt-4o', 'test-api-key')
chunks = [
self._make_chunk(
tool_calls=[
{
'index': 0,
'id': 'call_123',
'type': 'function',
'function': {'name': 'qa_plugin_echo', 'arguments': ''},
}
]
),
self._make_chunk(
tool_calls=[
{
'index': 0,
'id': None,
'type': None,
'function': {'name': None, 'arguments': '{"text":'},
}
]
),
self._make_chunk(
tool_calls=[
{
'index': 0,
'function': {'arguments': '"plugin-tool-ok"}'},
}
]
),
self._make_chunk(finish_reason='tool_calls'),
]
async def _aiter(*args, **kwargs):
for c in chunks:
yield c
query = Mock(spec=pipeline_query.Query)
query.variables = {}
messages = [provider_message.Message(role='user', content='Call the tool')]
funcs = [Mock()]
with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
collected = [
chunk async for chunk in requester.invoke_llm_stream(
query=query,
model=model,
messages=messages,
funcs=funcs,
)
]
tool_chunks = [chunk for chunk in collected if chunk.tool_calls]
assert len(tool_chunks) == 3
assert tool_chunks[1].tool_calls[0].id == 'call_123'
assert tool_chunks[1].tool_calls[0].function.name == 'qa_plugin_echo'
assert tool_chunks[1].tool_calls[0].function.arguments == '{"text":'
assert tool_chunks[2].tool_calls[0].function.arguments == '"plugin-tool-ok"}'
class TestProcessThinkingContent:
"""Test _process_thinking_content method"""
@@ -499,6 +577,32 @@ class TestInvokeLLM:
)
assert result_msg.tool_calls is not None
called_kwargs = litellmchat.acompletion.await_args.kwargs
assert called_kwargs['tools'] == [{'type': 'function', 'function': {'name': 'get_weather'}}]
assert called_kwargs['tool_choice'] == 'auto'
@pytest.mark.asyncio
async def test_build_completion_args_preserves_explicit_tool_choice(self):
"""Model extra args can override the default auto tool choice."""
mock_ap = Mock()
mock_ap.tool_mgr = Mock()
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
)
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
model = MockRuntimeModel('gpt-4o', 'test-api-key')
model.model_entity.extra_args = {'tool_choice': 'required'}
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.provider.message as provider_message
funcs = [Mock(spec=resource_tool.LLMTool)]
messages = [provider_message.Message(role='user', content='What is the weather?')]
args = await requester._build_completion_args(model, messages, funcs)
assert args['tool_choice'] == 'required'
@pytest.mark.asyncio
async def test_invoke_llm_error_handling(self):
@@ -754,6 +858,44 @@ class TestScanModels:
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
assert len(embedding_models) == 1
@pytest.mark.asyncio
async def test_scan_models_enriches_llm_abilities_and_context_length(self):
"""Scanned LLM models get LiteLLM-derived abilities and context length."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.openai.com/v1',
'timeout': 60,
},
)
requester._supports_function_calling = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
requester._supports_vision = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
requester._safe_context_length = Mock(side_effect=lambda model_id: 128000 if model_id == 'gpt-4o' else None)
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{'id': 'gpt-4o'},
{'id': 'text-embedding-3-small'},
{'id': 'bge-reranker-v2'},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
by_id = {model['id']: model for model in result['models']}
assert by_id['gpt-4o']['abilities'] == ['func_call', 'vision']
assert by_id['gpt-4o']['context_length'] == 128000
assert by_id['text-embedding-3-small']['type'] == 'embedding'
assert by_id['bge-reranker-v2']['type'] == 'rerank'
@pytest.mark.asyncio
async def test_scan_models_no_base_url(self):
"""Test scan_models without base_url raises error"""
@@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator
class RecordingProvider:
@@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query:
)
def test_stream_accumulator_merges_fragmented_tool_call_arguments():
accumulator = _StreamAccumulator(msg_sequence=1)
assert (
accumulator.add(
provider_message.MessageChunk(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id='call-1',
type='function',
function=provider_message.FunctionCall(name='exec', arguments='{"command":'),
)
],
)
)
is None
)
emitted = accumulator.add(
provider_message.MessageChunk(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id='call-1',
type='function',
function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'),
)
],
is_final=True,
)
)
assert emitted is not None
final_msg = accumulator.final_message()
assert final_msg.tool_calls[0].function.name == 'exec'
assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}'
@pytest.mark.asyncio
async def test_localagent_uses_exec_for_exact_calculation():
provider = RecordingProvider()