mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-06 22:06:03 +00:00
refactor(provider): simplify litellm capabilities
This commit is contained in:
@@ -109,7 +109,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
if llm_model:
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
if 'func_call' in (llm_model.model_entity.abilities or []):
|
||||
# Get bound plugins and MCP servers for filtering tools
|
||||
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
|
||||
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
|
||||
@@ -162,7 +162,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
if (
|
||||
selected_runner == 'local-agent'
|
||||
and llm_model
|
||||
and not llm_model.model_entity.abilities.__contains__('vision')
|
||||
and 'vision' not in (llm_model.model_entity.abilities or [])
|
||||
):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
@@ -181,7 +181,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or (
|
||||
llm_model and llm_model.model_entity.abilities.__contains__('vision')
|
||||
llm_model and 'vision' in (llm_model.model_entity.abilities or [])
|
||||
):
|
||||
if me.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
@@ -202,7 +202,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
content_list.append(provider_message.ContentElement.from_text(msg.text))
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or (
|
||||
llm_model and llm_model.model_entity.abilities.__contains__('vision')
|
||||
llm_model and 'vision' in (llm_model.model_entity.abilities or [])
|
||||
):
|
||||
if msg.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||
|
||||
@@ -37,16 +37,39 @@ class ModelManager:
|
||||
self.requester_components = []
|
||||
self.requester_dict = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_litellm_provider_from_manifest(component: engine.Component | None) -> str | None:
|
||||
if component is None:
|
||||
return None
|
||||
|
||||
spec = getattr(component, 'spec', None) or {}
|
||||
litellm_provider = None
|
||||
|
||||
if isinstance(spec, dict):
|
||||
litellm_provider = spec.get('litellm_provider')
|
||||
else:
|
||||
getter = getattr(spec, 'get', None)
|
||||
if callable(getter):
|
||||
try:
|
||||
litellm_provider = getter('litellm_provider')
|
||||
except Exception:
|
||||
litellm_provider = None
|
||||
|
||||
if isinstance(litellm_provider, str) and litellm_provider:
|
||||
return litellm_provider
|
||||
return None
|
||||
|
||||
async def initialize(self):
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
|
||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
# Skip components that use litellm_provider (they will use litellmchat.py instead)
|
||||
if component.spec.get('litellm_provider'):
|
||||
litellm_provider = self._get_litellm_provider_from_manifest(component)
|
||||
if litellm_provider:
|
||||
self.ap.logger.debug(
|
||||
f'Skipping Python class loading for {component.metadata.name} '
|
||||
f'(uses litellm_provider={component.spec.get("litellm_provider")})'
|
||||
f'(uses litellm_provider={litellm_provider})'
|
||||
)
|
||||
continue
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
@@ -303,17 +326,18 @@ class ModelManager:
|
||||
|
||||
# Get requester manifest to check for litellm_provider
|
||||
requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester)
|
||||
litellm_provider = self._get_litellm_provider_from_manifest(requester_manifest)
|
||||
|
||||
# Build config from base_url
|
||||
config = {'base_url': provider_entity.base_url}
|
||||
|
||||
# Check if requester manifest specifies litellm_provider
|
||||
if requester_manifest and requester_manifest.spec.get('litellm_provider'):
|
||||
if litellm_provider:
|
||||
from .requesters import litellmchat
|
||||
|
||||
# Use unified LiteLLMRequester with provider prefix
|
||||
# Map litellm_provider (YAML spec) to custom_llm_provider (config)
|
||||
config['custom_llm_provider'] = requester_manifest.spec['litellm_provider']
|
||||
config['custom_llm_provider'] = litellm_provider
|
||||
requester_inst = litellmchat.LiteLLMRequester(
|
||||
ap=self.ap,
|
||||
config=config,
|
||||
|
||||
@@ -16,6 +16,9 @@ import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
"""LiteLLM unified API requester supporting chat, embedding, and rerank."""
|
||||
|
||||
_EMBEDDING_MODEL_HINTS = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
||||
_RERANK_MODEL_HINTS = ('rerank', 're-rank', 're_rank')
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': '',
|
||||
'timeout': 120,
|
||||
@@ -36,10 +39,90 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '')
|
||||
if provider:
|
||||
# LiteLLM format: provider/model_name
|
||||
if model_name.startswith(f'{provider}/'):
|
||||
return model_name
|
||||
return f'{provider}/{model_name}'
|
||||
# If no custom provider, assume model_name already includes prefix or is OpenAI-compatible
|
||||
return model_name
|
||||
|
||||
def _get_custom_llm_provider(self) -> str | None:
|
||||
return self.requester_cfg.get('custom_llm_provider') or None
|
||||
|
||||
def _safe_litellm_bool_helper(self, helper_name: str, model_name: str) -> bool:
    """Evaluate a boolean helper looked up on the ``litellm`` module, never raising.

    The helper is tried with the bare model name plus the configured provider,
    and then, if the provider-prefixed name differs, with that prefixed name and
    no provider.

    Args:
        helper_name: Attribute name of the helper on the ``litellm`` module.
        model_name: Model identifier to probe.

    Returns:
        ``True`` as soon as one attempt returns a truthy value. ``False`` when
        the helper is missing or not callable, or when every attempt is falsy or
        raises.
    """
    capability_check = getattr(litellm, helper_name, None)
    if not callable(capability_check):
        return False

    attempts: list[tuple[str, str | None]] = [(model_name, self._get_custom_llm_provider())]
    prefixed_name = self._build_litellm_model_name(model_name)
    if prefixed_name != model_name:
        attempts.append((prefixed_name, None))

    for model_id, provider_hint in attempts:
        try:
            supported = bool(capability_check(model=model_id, custom_llm_provider=provider_hint))
        except Exception:
            # A failing lookup for one candidate must not abort the probe.
            continue
        if supported:
            return True
    return False
|
||||
|
||||
def _safe_context_length(self, model_name: str) -> int | None:
    """Look up a model's token limit via ``litellm.get_max_tokens`` without raising.

    The bare model name is tried first, then the provider-prefixed name when it
    differs.

    Returns:
        The first positive integer returned by the helper, or ``None`` when the
        helper is unavailable or no candidate yields one.
    """
    lookup = getattr(litellm, 'get_max_tokens', None)
    if not callable(lookup):
        return None

    prefixed_name = self._build_litellm_model_name(model_name)
    names = [model_name, prefixed_name] if prefixed_name != model_name else [model_name]

    for name in names:
        try:
            limit = lookup(name)
        except Exception:
            # Lookup failures for one name just move on to the next candidate.
            continue
        if isinstance(limit, int) and limit > 0:
            return limit
    return None
|
||||
|
||||
def _supports_function_calling(self, model_name: str) -> bool:
|
||||
return self._safe_litellm_bool_helper('supports_function_calling', model_name)
|
||||
|
||||
def _supports_vision(self, model_name: str) -> bool:
|
||||
return self._safe_litellm_bool_helper('supports_vision', model_name)
|
||||
|
||||
def _infer_model_type(self, model_id: str) -> str:
|
||||
normalized_id = (model_id or '').lower()
|
||||
if any(kw in normalized_id for kw in self._RERANK_MODEL_HINTS):
|
||||
return 'rerank'
|
||||
if any(kw in normalized_id for kw in self._EMBEDDING_MODEL_HINTS):
|
||||
return 'embedding'
|
||||
return 'llm'
|
||||
|
||||
def _enrich_scanned_model(self, model_id: str) -> dict[str, typing.Any]:
|
||||
model_type = self._infer_model_type(model_id)
|
||||
scanned_model: dict[str, typing.Any] = {
|
||||
'id': model_id,
|
||||
'name': model_id,
|
||||
'type': model_type,
|
||||
}
|
||||
|
||||
if model_type == 'llm':
|
||||
abilities = []
|
||||
if self._supports_function_calling(model_id):
|
||||
abilities.append('func_call')
|
||||
if self._supports_vision(model_id):
|
||||
abilities.append('vision')
|
||||
scanned_model['abilities'] = abilities
|
||||
|
||||
context_length = self._safe_context_length(model_id)
|
||||
if context_length is not None:
|
||||
scanned_model['context_length'] = context_length
|
||||
|
||||
return scanned_model
|
||||
|
||||
def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]:
|
||||
"""Convert LangBot messages to LiteLLM/OpenAI format."""
|
||||
req_messages = []
|
||||
@@ -121,6 +204,64 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
"""Extract usage info from a non-streaming LiteLLM response."""
|
||||
return self._normalize_usage(getattr(response, 'usage', None))
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: typing.Any) -> dict:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if hasattr(value, 'model_dump'):
|
||||
return value.model_dump()
|
||||
return {}
|
||||
|
||||
def _normalize_stream_tool_calls(
|
||||
self,
|
||||
raw_tool_calls: typing.Any,
|
||||
tool_call_state: dict[int, dict[str, str]],
|
||||
) -> list[dict] | None:
|
||||
"""Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
|
||||
if not raw_tool_calls:
|
||||
return None
|
||||
|
||||
normalized = []
|
||||
for fallback_index, raw_tool_call in enumerate(raw_tool_calls):
|
||||
tool_call = self._as_dict(raw_tool_call)
|
||||
index = tool_call.get('index')
|
||||
if not isinstance(index, int):
|
||||
index = fallback_index
|
||||
|
||||
state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
|
||||
if tool_call.get('id'):
|
||||
state['id'] = tool_call['id']
|
||||
if tool_call.get('type'):
|
||||
state['type'] = tool_call['type']
|
||||
|
||||
function = self._as_dict(tool_call.get('function'))
|
||||
if function.get('name'):
|
||||
state['name'] = function['name']
|
||||
|
||||
arguments = function.get('arguments')
|
||||
if arguments is None:
|
||||
arguments = ''
|
||||
elif not isinstance(arguments, str):
|
||||
arguments = str(arguments)
|
||||
|
||||
if not state['id'] or not state['name']:
|
||||
continue
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
'id': state['id'],
|
||||
'type': state['type'] or 'function',
|
||||
'function': {
|
||||
'name': state['name'],
|
||||
'arguments': arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return normalized or None
|
||||
|
||||
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
|
||||
"""Apply common requester config to args dict."""
|
||||
if self.requester_cfg.get('base_url'):
|
||||
@@ -189,6 +330,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
args.setdefault('tool_choice', 'auto')
|
||||
|
||||
return args
|
||||
|
||||
@@ -240,6 +382,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
|
||||
chunk_idx = 0
|
||||
role = 'assistant'
|
||||
tool_call_state: dict[int, dict[str, str]] = {}
|
||||
|
||||
try:
|
||||
response = await acompletion(**args)
|
||||
@@ -283,14 +426,16 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
# Use reasoning_content as the displayed content
|
||||
delta_content = reasoning_content
|
||||
|
||||
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
||||
tool_calls = self._normalize_stream_tool_calls(delta.get('tool_calls'), tool_call_state)
|
||||
|
||||
if chunk_idx == 0 and not delta_content and not tool_calls:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'tool_calls': tool_calls,
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
@@ -412,18 +557,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Infer model type
|
||||
normalized_id = (model_id or '').lower()
|
||||
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
||||
model_type = 'embedding' if any(kw in normalized_id for kw in embedding_keywords) else 'llm'
|
||||
|
||||
models.append(
|
||||
{
|
||||
'id': model_id,
|
||||
'name': model_id,
|
||||
'type': model_type,
|
||||
}
|
||||
)
|
||||
models.append(self._enrich_scanned_model(model_id))
|
||||
|
||||
models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))
|
||||
|
||||
|
||||
@@ -41,6 +41,64 @@ SANDBOX_EXEC_SYSTEM_GUIDANCE = (
|
||||
MAX_TOOL_CALL_ROUNDS = 128
|
||||
|
||||
|
||||
def _model_has_ability(model: modelmgr_requester.RuntimeLLMModel, ability: str) -> bool:
|
||||
return ability in (model.model_entity.abilities or [])
|
||||
|
||||
|
||||
class _StreamAccumulator:
    """Fold a stream of message chunks into running content and merged tool calls.

    Content fragments are concatenated, tool-call argument fragments are joined
    per tool-call id, and ``add`` hands back a cumulative snapshot on every
    eighth chunk and on the final chunk.
    """

    def __init__(self, msg_sequence: int = 0, initial_content: str | None = None):
        # Sequence number of the last snapshot; bumped before each new snapshot.
        self.msg_sequence = msg_sequence
        # Text gathered so far, optionally seeded by the caller.
        self.accumulated_content = initial_content or ''
        # Role of the most recent chunk that carried one.
        self.last_role = 'assistant'
        # Number of chunks passed to add().
        self.msg_idx = 0
        # Merged tool calls keyed by tool-call id, in first-seen order.
        self.tool_calls_map: dict[str, provider_message.ToolCall] = {}

    def _merge_tool_call(self, fragment) -> None:
        """Register the fragment's tool call if it is new and append its argument text."""
        call_id = fragment.id
        fn = fragment.function
        if call_id not in self.tool_calls_map:
            self.tool_calls_map[call_id] = provider_message.ToolCall(
                id=call_id,
                type=fragment.type,
                function=provider_message.FunctionCall(
                    name=fn.name if fn else '',
                    arguments='',
                ),
            )
        if fn and fn.arguments:
            self.tool_calls_map[call_id].function.arguments += fn.arguments

    def _merged_tool_calls(self):
        """Return the merged tool calls as a list, or ``None`` when there are none."""
        if not self.tool_calls_map:
            return None
        return list(self.tool_calls_map.values())

    def add(self, msg: provider_message.MessageChunk) -> provider_message.MessageChunk | None:
        """Absorb one streamed chunk.

        Returns:
            A cumulative chunk on every eighth call and whenever ``msg.is_final``
            is truthy, otherwise ``None``. Tool calls are attached only to the
            snapshot produced for a final chunk.
        """
        self.msg_idx += 1
        self.last_role = msg.role or self.last_role
        self.accumulated_content += msg.content or ''

        for fragment in msg.tool_calls or ():
            self._merge_tool_call(fragment)

        finished = msg.is_final
        if not finished and self.msg_idx % 8 != 0:
            return None

        self.msg_sequence += 1
        return provider_message.MessageChunk(
            role=self.last_role,
            content=self.accumulated_content,
            tool_calls=self._merged_tool_calls() if finished else None,
            is_final=finished,
            msg_sequence=self.msg_sequence,
        )

    def final_message(self) -> provider_message.MessageChunk:
        """Return a chunk holding everything accumulated so far, including merged tool calls."""
        return provider_message.MessageChunk(
            role=self.last_role,
            content=self.accumulated_content,
            tool_calls=self._merged_tool_calls(),
            msg_sequence=self.msg_sequence,
        )
|
||||
|
||||
|
||||
@runner.runner_class('local-agent')
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""Local agent request runner"""
|
||||
@@ -105,7 +163,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
funcs if _model_has_ability(model, 'func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
@@ -135,7 +193,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
funcs if _model_has_ability(model, 'func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
@@ -302,11 +360,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
final_msg = msg
|
||||
else:
|
||||
# Streaming: invoke with fallback
|
||||
tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
last_role = 'assistant'
|
||||
msg_sequence = 1
|
||||
stream_accumulator = _StreamAccumulator(msg_sequence=1)
|
||||
|
||||
stream_src, use_llm_model = await self._invoke_stream_with_fallback(
|
||||
query,
|
||||
@@ -316,44 +370,12 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
remove_think,
|
||||
)
|
||||
async for msg in stream_src:
|
||||
msg_idx = msg_idx + 1
|
||||
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = provider_message.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
chunk = stream_accumulator.add(msg)
|
||||
if chunk:
|
||||
yield chunk
|
||||
initial_response_emitted = True
|
||||
|
||||
final_msg = provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
final_msg = stream_accumulator.final_message()
|
||||
|
||||
pending_tool_calls = final_msg.tool_calls
|
||||
first_content = final_msg.content
|
||||
@@ -438,69 +460,36 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
tool_calls_map = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
last_role = 'assistant'
|
||||
msg_sequence = first_end_sequence
|
||||
stream_accumulator = _StreamAccumulator(
|
||||
msg_sequence=first_end_sequence,
|
||||
initial_content=first_content,
|
||||
)
|
||||
|
||||
tool_stream_src = use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
query.use_funcs
|
||||
if _model_has_ability(use_llm_model, 'func_call')
|
||||
else [],
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
async for msg in tool_stream_src:
|
||||
msg_idx += 1
|
||||
chunk = stream_accumulator.add(msg)
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
# Prepend first-round content on first chunk of tool-call round
|
||||
if msg_idx == 1:
|
||||
accumulated_content = first_content if first_content is not None else accumulated_content
|
||||
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = provider_message.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
final_msg = provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
final_msg = stream_accumulator.final_message()
|
||||
else:
|
||||
# Non-streaming: use committed model directly (no fallback in tool loop)
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
query.use_funcs
|
||||
if _model_has_ability(use_llm_model, 'func_call')
|
||||
else [],
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
@@ -68,6 +68,12 @@ class TestBuildLiteLLMModelName:
|
||||
result = requester._build_litellm_model_name('gpt-4o')
|
||||
assert result == 'openai/gpt-4o'
|
||||
|
||||
def test_avoid_duplicate_provider_prefix(self):
    """A model name already carrying the configured provider prefix is returned unchanged."""
    already_prefixed = 'openai/gpt-4o'
    llm_requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})

    assert llm_requester._build_litellm_model_name(already_prefixed) == 'openai/gpt-4o'
|
||||
|
||||
def test_override_provider(self):
|
||||
"""Test override provider via parameter"""
|
||||
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||
@@ -151,7 +157,7 @@ class TestInvokeLLMStreamUsage:
|
||||
calls record 0 tokens.
|
||||
"""
|
||||
|
||||
def _make_chunk(self, *, content=None, finish_reason=None, usage=None, has_choice=True):
|
||||
def _make_chunk(self, *, content=None, tool_calls=None, finish_reason=None, usage=None, has_choice=True):
|
||||
chunk = Mock()
|
||||
if usage is not None:
|
||||
chunk.usage = usage
|
||||
@@ -161,7 +167,7 @@ class TestInvokeLLMStreamUsage:
|
||||
choice = Mock()
|
||||
delta = Mock()
|
||||
delta.model_dump = Mock(
|
||||
return_value={'role': 'assistant', 'content': content, 'tool_calls': None}
|
||||
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
|
||||
)
|
||||
choice.delta = delta
|
||||
choice.finish_reason = finish_reason
|
||||
@@ -250,6 +256,78 @@ class TestInvokeLLMStreamUsage:
|
||||
|
||||
assert query.variables['_stream_usage']['total_tokens'] == 12
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_tool_call_delta_missing_id_and_name(self):
    """Streamed tool-call deltas whose id/name are None or absent still yield complete tool calls."""
    import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
    import langbot_plugin.api.entities.builtin.provider.message as provider_message

    app = Mock()
    app.tool_mgr = Mock()
    app.tool_mgr.generate_tools_for_openai = AsyncMock(
        return_value=[{'type': 'function', 'function': {'name': 'qa_plugin_echo'}}]
    )
    requester = litellmchat.LiteLLMRequester(ap=app, config={})
    model = MockRuntimeModel('gpt-4o', 'test-api-key')

    # First delta carries id/name; the following ones only carry argument text.
    deltas = [
        {
            'index': 0,
            'id': 'call_123',
            'type': 'function',
            'function': {'name': 'qa_plugin_echo', 'arguments': ''},
        },
        {
            'index': 0,
            'id': None,
            'type': None,
            'function': {'name': None, 'arguments': '{"text":'},
        },
        {
            'index': 0,
            'function': {'arguments': '"plugin-tool-ok"}'},
        },
    ]
    chunks = [self._make_chunk(tool_calls=[delta]) for delta in deltas]
    chunks.append(self._make_chunk(finish_reason='tool_calls'))

    async def _aiter(*args, **kwargs):
        for item in chunks:
            yield item

    query = Mock(spec=pipeline_query.Query)
    query.variables = {}
    messages = [provider_message.Message(role='user', content='Call the tool')]
    funcs = [Mock()]

    collected = []
    with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
        async for chunk in requester.invoke_llm_stream(
            query=query,
            model=model,
            messages=messages,
            funcs=funcs,
        ):
            collected.append(chunk)

    tool_chunks = [chunk for chunk in collected if chunk.tool_calls]
    assert len(tool_chunks) == 3

    second_call = tool_chunks[1].tool_calls[0]
    assert second_call.id == 'call_123'
    assert second_call.function.name == 'qa_plugin_echo'
    assert second_call.function.arguments == '{"text":'
    assert tool_chunks[2].tool_calls[0].function.arguments == '"plugin-tool-ok"}'
|
||||
|
||||
|
||||
class TestProcessThinkingContent:
|
||||
"""Test _process_thinking_content method"""
|
||||
@@ -499,6 +577,32 @@ class TestInvokeLLM:
|
||||
)
|
||||
|
||||
assert result_msg.tool_calls is not None
|
||||
called_kwargs = litellmchat.acompletion.await_args.kwargs
|
||||
assert called_kwargs['tools'] == [{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||
assert called_kwargs['tool_choice'] == 'auto'
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_completion_args_preserves_explicit_tool_choice(self):
    """A ``tool_choice`` set in the model's extra args is kept in the built completion args."""
    import langbot_plugin.api.entities.builtin.provider.message as provider_message
    import langbot_plugin.api.entities.builtin.resource.tool as resource_tool

    app = Mock()
    app.tool_mgr = Mock()
    app.tool_mgr.generate_tools_for_openai = AsyncMock(
        return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
    )
    requester = litellmchat.LiteLLMRequester(ap=app, config={})

    model = MockRuntimeModel('gpt-4o', 'test-api-key')
    model.model_entity.extra_args = {'tool_choice': 'required'}

    tools = [Mock(spec=resource_tool.LLMTool)]
    conversation = [provider_message.Message(role='user', content='What is the weather?')]

    built_args = await requester._build_completion_args(model, conversation, tools)

    assert built_args['tool_choice'] == 'required'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_error_handling(self):
|
||||
@@ -754,6 +858,44 @@ class TestScanModels:
|
||||
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
|
||||
assert len(embedding_models) == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_models_enriches_llm_abilities_and_context_length(self):
    """Scanned LLM entries carry abilities and context length; other entries keep their inferred type."""
    requester = litellmchat.LiteLLMRequester(
        ap=Mock(),
        config={'base_url': 'https://api.openai.com/v1', 'timeout': 60},
    )
    # Only 'gpt-4o' is reported as capable / as having a known context length.
    requester._supports_function_calling = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
    requester._supports_vision = Mock(side_effect=lambda model_id: model_id == 'gpt-4o')
    requester._safe_context_length = Mock(side_effect=lambda model_id: 128000 if model_id == 'gpt-4o' else None)

    listing = {
        'data': [
            {'id': 'gpt-4o'},
            {'id': 'text-embedding-3-small'},
            {'id': 'bge-reranker-v2'},
        ]
    }
    mock_response = Mock()
    mock_response.json = Mock(return_value=listing)
    mock_response.raise_for_status = Mock()

    with patch('httpx.AsyncClient') as mock_client:
        http_session = Mock()
        http_session.get = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__ = AsyncMock(return_value=http_session)

        result = await requester.scan_models(api_key='test-key')

    by_id = {entry['id']: entry for entry in result['models']}
    assert by_id['gpt-4o']['abilities'] == ['func_call', 'vision']
    assert by_id['gpt-4o']['context_length'] == 128000
    assert by_id['text-embedding-3-small']['type'] == 'embedding'
    assert by_id['bge-reranker-v2']['type'] == 'rerank'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_no_base_url(self):
|
||||
"""Test scan_models without base_url raises error"""
|
||||
|
||||
@@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator
|
||||
|
||||
|
||||
class RecordingProvider:
|
||||
@@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query:
|
||||
)
|
||||
|
||||
|
||||
def test_stream_accumulator_merges_fragmented_tool_call_arguments():
    """Argument fragments sharing one tool-call id are joined into a single tool call."""

    def fragment(arguments: str, **chunk_fields):
        # One assistant chunk carrying a piece of the 'call-1' tool call.
        return provider_message.MessageChunk(
            role='assistant',
            tool_calls=[
                provider_message.ToolCall(
                    id='call-1',
                    type='function',
                    function=provider_message.FunctionCall(name='exec', arguments=arguments),
                )
            ],
            **chunk_fields,
        )

    accumulator = _StreamAccumulator(msg_sequence=1)

    # A non-final first chunk produces no snapshot.
    assert accumulator.add(fragment('{"command":')) is None

    emitted = accumulator.add(fragment('"pwd"}', is_final=True))
    assert emitted is not None

    merged_call = accumulator.final_message().tool_calls[0]
    assert merged_call.function.name == 'exec'
    assert merged_call.function.arguments == '{"command":"pwd"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_uses_exec_for_exact_calculation():
|
||||
provider = RecordingProvider()
|
||||
|
||||
@@ -130,32 +130,6 @@ export default function AddModelPopover({
|
||||
setScanLoading(true);
|
||||
try {
|
||||
const result = await onScanModels(trigger ? undefined : tab);
|
||||
|
||||
const debugData = (
|
||||
result.debug?.response as { data?: Record<string, unknown>[] }
|
||||
)?.data;
|
||||
if (Array.isArray(debugData)) {
|
||||
const debugMap = new Map<string, Record<string, unknown>>();
|
||||
for (const item of debugData) {
|
||||
if (typeof item?.id === 'string') {
|
||||
debugMap.set(item.id, item);
|
||||
}
|
||||
}
|
||||
for (const model of result.models) {
|
||||
const debugItem = debugMap.get(model.id);
|
||||
if (!debugItem) continue;
|
||||
const features = debugItem.features as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
const tools = features?.tools as Record<string, unknown> | undefined;
|
||||
if (tools?.function_calling === true) {
|
||||
const nextAbilities = new Set(model.abilities || []);
|
||||
nextAbilities.add('func_call');
|
||||
model.abilities = [...nextAbilities];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setScannedModels(result.models);
|
||||
setSelectedScannedModels({});
|
||||
} finally {
|
||||
|
||||
Reference in New Issue
Block a user