LangBot/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
huanghuoguoguo 9ecb587ac0 refactor(provider): use LiteLLM as unified LLM requester backend (#2150)
* refactor(provider): use LiteLLM as unified LLM requester backend

  - Replace 23+ individual requester implementations with unified litellmchat.py
  - Add litellm_provider field to 27 YAML manifests for provider routing
  - Delete redundant requester subclasses
  - Add unit tests for LiteLLMRequester (29 tests)
  - Fix num_retries parameter name (was max_retries)
  - Fix exception handling order for subclass exceptions
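
The exception-ordering fix can be illustrated with a minimal sketch. The two exception classes below are hypothetical stand-ins rather than LiteLLM's real hierarchy; the point is only that a subclass must be tested before its base class, otherwise the base-class branch matches first and the specific message is never produced.

```python
class BadRequest(Exception):
    """Hypothetical base error (stand-in for a provider's generic 400)."""


class ContextWindowExceeded(BadRequest):
    """Hypothetical, more specific error that inherits from BadRequest."""


def classify(exc: Exception) -> str:
    # Check the subclass first: isinstance(exc, BadRequest) is also True
    # for ContextWindowExceeded, so the order of these branches matters.
    if isinstance(exc, ContextWindowExceeded):
        return 'context window exceeded'
    if isinstance(exc, BadRequest):
        return 'bad request'
    return 'unknown error'


print(classify(ContextWindowExceeded('too long')))
print(classify(BadRequest('bad field')))
```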

  LiteLLM provides a unified API for 100+ providers, eliminating the need
  for provider-specific requesters.

* fix: ruff format provider.py

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* refactor(provider): simplify LiteLLM requester usage handling

  - Remove unused Anthropic-specific tool schema generation
  - Share completion argument construction between normal and streaming calls
  - Use LiteLLM/OpenAI native usage fields for monitoring
  - Collect stream token usage from LiteLLM stream_options
  - Update LiteLLM requester tests for unified usage fields
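
As a rough illustration of sharing argument construction between the normal and streaming paths, the sketch below builds one argument dict and only adds the streaming keys when requested. The function name and parameters are simplified assumptions; the `stream_options={'include_usage': True}` setting follows the OpenAI-style option that this change relies on for stream token usage.

```python
def build_completion_args(model, messages, api_key, stream=False, extra_args=None):
    """Build one argument dict used by both streaming and non-streaming calls."""
    args = {'model': model, 'messages': messages, 'api_key': api_key}
    if stream:
        # Ask the upstream to append a usage payload to the stream.
        args['stream'] = True
        args['stream_options'] = {'include_usage': True}
    # Caller-supplied overrides are applied last (assumed precedence).
    args.update(extra_args or {})
    return args


normal = build_completion_args('demo-model', [], 'sk-demo')
streamed = build_completion_args('demo-model', [], 'sk-demo', stream=True)
print(sorted(set(streamed) - set(normal)))
```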

* restore: restore deleted provider requester files

Restore individual provider requester implementations that were
removed in de61b5d3. These files coexist with the unified
litellmchat.py backend.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* feat: update requesters and improve provider selection UI

- Added `litellm_provider` field to various requesters' YAML configurations.
- Removed obsolete Python requester files for OpenRouter, PPIO, QHAIGC, ShengSuanYun, SiliconFlow, Space, TokenPony, VolcArk, and Xai.
- Introduced new requesters for Tencent and Together AI with corresponding YAML configurations and SVG icons.
- Enhanced the ProviderForm component to include a searchable dropdown for selecting providers, improving user experience.
- Updated localization files to include search provider text for both English and Chinese.

* fix(provider): align litellm rebase with master

* fix(provider): capture streaming token usage; add token observability

The LiteLLM streaming requester only captured usage when a chunk had an
empty `choices` list. Many OpenAI-compatible gateways (e.g. new-api) and
providers send the final usage payload in a chunk that still carries an
empty-delta choice, so streamed calls always recorded 0 tokens in the
monitoring logs/dashboard (non-streaming worked).

- Capture stream usage whenever a chunk carries it, regardless of choices
- Add robust _normalize_usage (dict/obj shapes, derive missing total_tokens)
- Register litellm in bootutils/deps.py (was in pyproject only)
- Add MonitoringService.get_token_statistics + /monitoring/token-statistics
  endpoint: summary, per-model breakdown, token timeseries, and a
  zero-token-success data-quality signal
- Add TokenMonitoring dashboard tab (summary tiles, stacked token chart,
  per-model table) + i18n (en/zh)
- Regression tests for stream usage capture and usage normalization

Verified end-to-end against a real OpenAI-compatible endpoint with
gpt-5.5 and claude-opus-4-8: tokens now recorded non-zero for both
streaming and non-streaming paths.
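
The streaming fix described above can be sketched as follows. This is a simplified stand-alone illustration, not the requester itself: chunks are modeled with `SimpleNamespace`, deltas are plain dicts, and `collect_stream` is a hypothetical helper. It shows usage being recorded whenever a chunk carries it, even when that chunk still has an empty-delta choice, and `total_tokens` being derived when the upstream omits it.

```python
from types import SimpleNamespace


def normalize_usage(usage) -> dict:
    """Coerce a dict- or object-shaped usage payload into plain ints."""
    if usage is None:
        return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

    def get(key):
        return usage.get(key) if isinstance(usage, dict) else getattr(usage, key, None)

    prompt = int(get('prompt_tokens') or 0)
    completion = int(get('completion_tokens') or 0)
    # Derive the total when it is missing or zero.
    total = int(get('total_tokens') or 0) or prompt + completion
    return {'prompt_tokens': prompt, 'completion_tokens': completion, 'total_tokens': total}


def collect_stream(chunks):
    """Return (text, usage) for an iterable of stream chunks."""
    usage = normalize_usage(None)
    parts = []
    for chunk in chunks:
        # Capture usage independently of whether `choices` is empty.
        if getattr(chunk, 'usage', None):
            usage = normalize_usage(chunk.usage)
        for choice in getattr(chunk, 'choices', None) or []:
            if choice.delta.get('content'):
                parts.append(choice.delta['content'])
    return ''.join(parts), usage


chunks = [
    SimpleNamespace(choices=[SimpleNamespace(delta={'content': 'Hi'})], usage=None),
    # Final chunk: empty delta, but it still carries the usage payload.
    SimpleNamespace(choices=[SimpleNamespace(delta={})],
                    usage={'prompt_tokens': 3, 'completion_tokens': 1}),
]
print(collect_stream(chunks))
```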

* refactor(provider): simplify litellm capabilities

* style: simplify wrapped expressions

* feat(models): persist context metadata

* fix(provider): handle dict embeddings and openai-compatible rerank in LiteLLMRequester

- invoke_embedding: support both object- and dict-shaped response.data
  entries (OpenAI-compatible gateways like new-api return dicts)
- invoke_rerank: litellm.arerank rejects the 'openai' provider, so for
  openai-compatible (or unspecified) providers call the standard
  Jina/Cohere-style POST /v1/rerank endpoint directly over HTTP
- accept both 'relevance_score' and 'score' fields in rerank results
- add unit tests for the openai-compatible HTTP rerank path
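
The two shape-tolerance fixes above can be sketched in isolation. The helper names are illustrative and the sample payloads are made up; the logic mirrors the description: embedding entries may be objects or dicts, and rerank results may report either `relevance_score` or `score`.

```python
from types import SimpleNamespace


def extract_embeddings(data) -> list:
    """Accept both dict-shaped and object-shaped embedding entries."""
    return [d['embedding'] if isinstance(d, dict) else d.embedding for d in data]


def parse_rerank_results(payload) -> list:
    """Normalize a Jina/Cohere-style rerank response body."""
    raw = payload.get('results', []) if isinstance(payload, dict) else []
    return [
        {
            'index': r.get('index', 0),
            # Prefer 'relevance_score', fall back to 'score', then 0.0.
            'relevance_score': r.get('relevance_score', r.get('score', 0.0)) or 0.0,
        }
        for r in raw
    ]


mixed = [{'embedding': [0.1, 0.2]}, SimpleNamespace(embedding=[0.3, 0.4])]
print(extract_embeddings(mixed))
print(parse_rerank_results({'results': [{'index': 1, 'score': 0.9},
                                        {'index': 0, 'relevance_score': 0.2}]}))
```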

* feat(provider): enforce requester support_type when adding models

- frontend: AddModelPopover only shows model-type tabs (llm/embedding/
  rerank) that the provider's requester declares in its manifest
  support_type; ModelsDialog fetches requester manifests and maps
  requester -> support_type, passed down through ProviderCard
- backend: add _validate_provider_supports guard in create_llm_model /
  create_embedding_model / create_rerank_model so a model cannot be
  attached to a provider whose requester does not support that type,
  even if the frontend restriction is bypassed (manifests without
  support_type are allowed for backward compatibility)
- manifests: correct support_type for providers that do not offer all
  three model types:
  - llm only: anthropic, deepseek, groq, moonshot, openrouter, xai
  - llm + text-embedding: openai, gemini, mistral
  - add rerank to new-api (verified working via /v1/rerank)
  - set llm + text-embedding + rerank for aggregator/unknown gateways
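
A minimal sketch of the backend guard, under stated assumptions: manifests are modeled as a plain dict keyed by requester name, the function name and error type are hypothetical, and the type strings follow the ones listed above. It is meant to show the decision rule only, not the actual `_validate_provider_supports` implementation.

```python
def validate_provider_supports(manifests: dict, requester_name: str, model_type: str) -> None:
    """Raise if the requester declares a support_type that excludes model_type."""
    manifest = manifests.get(requester_name)
    if manifest is None:
        return  # unknown requester: nothing to validate against
    support_type = manifest.get('support_type')
    if not support_type:
        return  # manifest without support_type: allowed for backward compatibility
    if model_type not in support_type:
        raise ValueError(f'{requester_name} does not support model type {model_type!r}')


# Hypothetical manifest data shaped after the list above.
manifests = {
    'anthropic': {'support_type': ['llm']},
    'openai': {'support_type': ['llm', 'text-embedding']},
    'legacy': {},
}

validate_provider_supports(manifests, 'openai', 'text-embedding')  # passes
validate_provider_supports(manifests, 'legacy', 'rerank')          # passes
try:
    validate_provider_supports(manifests, 'anthropic', 'rerank')
except ValueError as exc:
    print(exc)
```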

* feat(provider): add searchable alias to requester manifests

- add a free-text 'alias' field to every requester manifest spec,
  containing the vendor's English/Chinese names, pinyin, common
  nicknames and flagship model-series names (e.g. moonshot -> kimi,
  月之暗面; zhipu -> glm, 智谱清言)
- frontend: ProviderForm requester search now also matches against
  alias (substring/contains), so searching 'kimi' surfaces Moonshot,
  '硅基' surfaces SiliconFlow, etc.
- also fix support_type: openrouter (relay) supports embedding+rerank;
  LangBot Space gains rerank (coming soon)
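
The alias matching can be sketched as a case-insensitive substring test over the name plus the alias text. The real search lives in the frontend ProviderForm component; the Python below is only an illustration of the matching rule, and the provider records (including the exact alias strings) are made-up sample data.

```python
def matches(provider: dict, query: str) -> bool:
    """Case-insensitive substring match against name and alias."""
    needle = query.strip().lower()
    haystack = f"{provider.get('name', '')} {provider.get('alias', '')}".lower()
    return needle in haystack


def search(providers: list, query: str) -> list:
    return [p['name'] for p in providers if matches(p, query)]


# Hypothetical sample records.
providers = [
    {'name': 'Moonshot', 'alias': 'moonshot kimi 月之暗面'},
    {'name': 'SiliconFlow', 'alias': 'siliconflow 硅基流动'},
]

print(search(providers, 'kimi'))
print(search(providers, '硅基'))
```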

* fix(provider): make support_type guard defensive against incomplete model_mgr

- _validate_provider_supports now uses getattr to gracefully skip when
  model_mgr / provider_dict / manifest lookup is unavailable, instead of
  raising AttributeError (fixes unit tests that mock ap.model_mgr as a
  bare SimpleNamespace)
- add TestValidateProviderSupports covering: allow supported type,
  reject unsupported type, allow when support_type missing, allow when
  provider unknown, degrade safely when model_mgr is incomplete

* fix(persistence): guard 0004 migration against missing llm_models table

The 0004_add_llm_model_context_length migration called
inspector.get_columns('llm_models') unconditionally, raising
NoSuchTableError when the table does not exist (e.g. migrating a
fresh/empty DB, as exercised by the integration tests where
create_all() registers no tables because the ORM models are not
imported). Every other migration guards with a table-existence check
first; add the same guard here for both upgrade and downgrade.

Also restore the test head assertion to 0004 (it had been lowered to
0003 to mask this failure).
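
The guard pattern can be illustrated without Alembic. The sketch below uses only the standard-library `sqlite3` module as a stand-in for the SQLAlchemy inspector, and the helper names and the `context_length` column name are assumptions: it checks that the table exists before reading its columns, so an empty database is skipped instead of raising.

```python
import sqlite3


def table_exists(conn: sqlite3.Connection, table: str) -> bool:
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    return row is not None


def add_column_if_missing(conn: sqlite3.Connection, table: str, column: str, ddl_type: str) -> bool:
    """Add the column only when the table exists and lacks it. Returns True if added."""
    if not table_exists(conn, table):
        return False  # fresh/empty DB: nothing to migrate
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    columns = {row[1] for row in conn.execute(f'PRAGMA table_info({table})')}
    if column in columns:
        return False
    conn.execute(f'ALTER TABLE {table} ADD COLUMN {column} {ddl_type}')
    return True


conn = sqlite3.connect(':memory:')
print(add_column_if_missing(conn, 'llm_models', 'context_length', 'INTEGER'))
conn.execute('CREATE TABLE llm_models (uuid TEXT)')
print(add_column_if_missing(conn, 'llm_models', 'context_length', 'INTEGER'))
```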

* Merge branch 'master' into feat/litellm

Resolve conflicts:
- uv.lock: regenerated via 'uv lock' to reconcile litellm/fastuuid
  (ours) with openai bump (master).
- Alembic migrations: master added 0004_add_mcp_readme while this
  branch added 0004_add_llm_model_context_length, both as children of
  0003 (would create multiple heads). Re-chain the litellm migration as
  0005_add_llm_model_context_length with down_revision=0004_add_mcp_readme
  for a single linear head. Update test head assertion accordingly.

* fix(persistence): shorten migration revision id to fit varchar(32)

PostgreSQL stores alembic_version.version_num as varchar(32).
'0005_add_llm_model_context_length' (33 chars) overflowed it, raising
StringDataRightTruncationError in the PG migration tests. Rename the
revision (and file) to '0005_add_llm_context_length' (27 chars) and
update the head assertions in both SQLite and PostgreSQL migration
tests.
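
A small check like the following could catch this class of problem before it reaches PostgreSQL. The helper name is hypothetical; the 32-character limit and the two revision ids come from the description above.

```python
MAX_REVISION_LEN = 32  # alembic_version.version_num is varchar(32)


def fits_version_num(revision_id: str, limit: int = MAX_REVISION_LEN) -> bool:
    """Return True if the revision id fits in the version_num column."""
    return len(revision_id) <= limit


for rev in ('0005_add_llm_model_context_length', '0005_add_llm_context_length'):
    print(rev, len(rev), fits_version_num(rev))
```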

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: fdc310 <2213070223@qq.com>
Co-authored-by: RockChinQ <rockchinq@gmail.com>
2026-06-13 16:59:48 +08:00

734 lines
29 KiB
Python

"""LiteLLM unified requester for chat, embedding, and rerank."""

from __future__ import annotations

import typing

import litellm
from litellm import acompletion, aembedding, arerank

from .. import errors, requester
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message


class LiteLLMRequester(requester.ProviderAPIRequester):
    """LiteLLM unified API requester supporting chat, embedding, and rerank."""

    _EMBEDDING_MODEL_HINTS = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
    _RERANK_MODEL_HINTS = ('rerank', 're-rank', 're_rank')

    default_config: dict[str, typing.Any] = {
        'base_url': '',
        'timeout': 120,
        'custom_llm_provider': '',
        'drop_params': False,
        'num_retries': 0,
        'api_version': '',
    }

    async def initialize(self):
        """Initialize LiteLLM client settings."""
        # LiteLLM doesn't require explicit client initialization
        # Configuration is passed per-request via litellm params
        pass

    def _build_litellm_model_name(self, model_name: str, custom_llm_provider: str | None = None) -> str:
        """Build LiteLLM model name with provider prefix if needed."""
        provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '')
        if provider:
            # LiteLLM format: provider/model_name
            if model_name.startswith(f'{provider}/'):
                return model_name
            return f'{provider}/{model_name}'
        # If no custom provider, assume model_name already includes prefix or is OpenAI-compatible
        return model_name

    def _get_custom_llm_provider(self) -> str | None:
        return self.requester_cfg.get('custom_llm_provider') or None

    def _safe_litellm_bool_helper(self, helper_name: str, model_name: str) -> bool:
        """Call a LiteLLM boolean capability helper without letting metadata gaps fail requests."""
        helper = getattr(litellm, helper_name, None)
        if not callable(helper):
            return False

        provider = self._get_custom_llm_provider()
        candidates: list[tuple[str, str | None]] = [(model_name, provider)]
        litellm_model_name = self._build_litellm_model_name(model_name)
        if litellm_model_name != model_name:
            candidates.append((litellm_model_name, None))
        for metadata_provider in self._metadata_provider_candidates(model_name):
            candidates.append((f'{metadata_provider}/{model_name}', None))

        tried_candidates: set[tuple[str, str | None]] = set()
        for candidate_model, candidate_provider in candidates:
            candidate_key = (candidate_model, candidate_provider)
            if candidate_key in tried_candidates:
                continue
            tried_candidates.add(candidate_key)
            try:
                if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)):
                    return True
            except Exception:
                continue
        return False

    def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
        if not model_payload:
            return None
        for field_name in ('context_length', 'context_window', 'max_context_length'):
            value = model_payload.get(field_name)
            if isinstance(value, bool):
                continue
            if isinstance(value, int) and value > 0:
                return value
            if isinstance(value, str) and value.isdigit():
                parsed_value = int(value)
                if parsed_value > 0:
                    return parsed_value
        return None

    def _metadata_provider_candidates(self, model_name: str) -> list[str]:
        normalized_model_name = (model_name or '').lower()
        candidates = []
        if normalized_model_name.startswith(('moonshot-', 'kimi-')):
            candidates.append('moonshot')
        if normalized_model_name.startswith('deepseek-'):
            candidates.append('deepseek')

        base_url = self.requester_cfg.get('base_url', '').lower()
        if 'moonshot' in base_url:
            candidates.append('moonshot')
        if 'deepseek' in base_url:
            candidates.append('deepseek')

        deduped_candidates = []
        for candidate in candidates:
            if candidate not in deduped_candidates:
                deduped_candidates.append(candidate)
        return deduped_candidates

    def _known_context_length_fallback(self, model_name: str) -> int | None:
        normalized_model_name = (model_name or '').lower()
        if normalized_model_name.startswith('deepseek-v4-'):
            return 1_000_000
        if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')):
            return 256 * 1024
        if normalized_model_name.startswith('moonshot-v1-8k'):
            return 8 * 1024
        if normalized_model_name.startswith('moonshot-v1-32k'):
            return 32 * 1024
        if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto':
            return 128 * 1024
        return None

    def _safe_context_length(self, model_name: str) -> int | None:
        helper = getattr(litellm, 'get_max_tokens', None)
        if not callable(helper):
            return self._known_context_length_fallback(model_name)

        candidates = [model_name]
        litellm_model_name = self._build_litellm_model_name(model_name)
        if litellm_model_name != model_name:
            candidates.append(litellm_model_name)
        for provider in self._metadata_provider_candidates(model_name):
            candidates.append(f'{provider}/{model_name}')

        tried_candidates = []
        for candidate in candidates:
            if candidate in tried_candidates:
                continue
            tried_candidates.append(candidate)
            try:
                max_tokens = helper(candidate)
            except Exception:
                continue
            if isinstance(max_tokens, int) and max_tokens > 0:
                return max_tokens
        return self._known_context_length_fallback(model_name)

    def _supports_function_calling(self, model_name: str) -> bool:
        return self._safe_litellm_bool_helper('supports_function_calling', model_name)

    def _supports_vision(self, model_name: str) -> bool:
        return self._safe_litellm_bool_helper('supports_vision', model_name)

    def _infer_model_type(self, model_id: str) -> str:
        normalized_id = (model_id or '').lower()
        if any(kw in normalized_id for kw in self._RERANK_MODEL_HINTS):
            return 'rerank'
        if any(kw in normalized_id for kw in self._EMBEDDING_MODEL_HINTS):
            return 'embedding'
        return 'llm'

    def _enrich_scanned_model(
        self,
        model_id: str,
        model_payload: dict[str, typing.Any] | None = None,
    ) -> dict[str, typing.Any]:
        model_type = self._infer_model_type(model_id)
        scanned_model: dict[str, typing.Any] = {
            'id': model_id,
            'name': model_id,
            'type': model_type,
        }
        if model_type == 'llm':
            abilities = []
            if self._supports_function_calling(model_id):
                abilities.append('func_call')
            supports_provider_reported_vision = bool(
                model_payload
                and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True)
            )
            if supports_provider_reported_vision or self._supports_vision(model_id):
                abilities.append('vision')
            scanned_model['abilities'] = abilities

            context_length = self._context_length_from_scan_payload(model_payload)
            if context_length is None:
                context_length = self._safe_context_length(model_id)
            if context_length is not None:
                scanned_model['context_length'] = context_length
        return scanned_model

    def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]:
        """Convert LangBot messages to LiteLLM/OpenAI format."""
        req_messages = []
        for m in messages:
            msg_dict = m.dict(exclude_none=True)
            content = msg_dict.get('content')
            if isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get('type') == 'image_base64':
                        part['image_url'] = {'url': part['image_base64']}
                        part['type'] = 'image_url'
                        del part['image_base64']
            req_messages.append(msg_dict)
        return req_messages

    def _process_thinking_content(self, content: str, reasoning_content: str | None, remove_think: bool) -> str:
        """Process thinking/reasoning content.

        Args:
            content: The main content from response
            reasoning_content: Separate reasoning content from model
            remove_think: If True, remove thinking markers; if False, preserve them

        Returns:
            Processed content string
        """
        # Extract and handle thinking tags
        if content and 'CRETIRE_REASONING_BEGINk' in content and 'CRETIRE_REASONING_ENDk' in content:
            import re

            think_pattern = r'CRETIRE_REASONING_BEGINk(.*?)CRETIRE_REASONING_ENDk'
            if remove_think:
                # Remove thinking tags and their content from output
                content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
            # else: preserve thinking content as-is

        # Handle separate reasoning_content field
        # Currently we don't include reasoning_content in user-facing output regardless of remove_think
        # because it's typically internal model reasoning, not user-visible thinking
        return content or ''

    @staticmethod
    def _normalize_usage(usage: typing.Any) -> dict:
        """Normalize a LiteLLM/OpenAI usage object into a plain token dict.

        Handles several real-world shapes returned by different upstreams:
        - object with ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens`` attrs
        - dict with the same keys
        - missing ``total_tokens`` (derived from prompt + completion)
        - ``None`` / partially-populated usage (defaults to 0)
        """
        if usage is None:
            return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

        def _get(key: str) -> typing.Any:
            if isinstance(usage, dict):
                return usage.get(key)
            return getattr(usage, key, None)

        prompt_tokens = _get('prompt_tokens') or 0
        completion_tokens = _get('completion_tokens') or 0
        total_tokens = _get('total_tokens') or 0
        # Some providers omit total_tokens in streaming usage; derive it.
        if not total_tokens:
            total_tokens = prompt_tokens + completion_tokens
        return {
            'prompt_tokens': int(prompt_tokens),
            'completion_tokens': int(completion_tokens),
            'total_tokens': int(total_tokens),
        }

    def _extract_usage(self, response) -> dict:
        """Extract usage info from a non-streaming LiteLLM response."""
        return self._normalize_usage(getattr(response, 'usage', None))

    @staticmethod
    def _as_dict(value: typing.Any) -> dict:
        if value is None:
            return {}
        if isinstance(value, dict):
            return value
        if hasattr(value, 'model_dump'):
            return value.model_dump()
        return {}

    def _normalize_stream_tool_calls(
        self,
        raw_tool_calls: typing.Any,
        tool_call_state: dict[int, dict[str, str]],
    ) -> list[dict] | None:
        """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
        if not raw_tool_calls:
            return None

        normalized = []
        for fallback_index, raw_tool_call in enumerate(raw_tool_calls):
            tool_call = self._as_dict(raw_tool_call)
            index = tool_call.get('index')
            if not isinstance(index, int):
                index = fallback_index

            state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
            if tool_call.get('id'):
                state['id'] = tool_call['id']
            if tool_call.get('type'):
                state['type'] = tool_call['type']

            function = self._as_dict(tool_call.get('function'))
            if function.get('name'):
                state['name'] = function['name']

            arguments = function.get('arguments')
            if arguments is None:
                arguments = ''
            elif not isinstance(arguments, str):
                arguments = str(arguments)

            if not state['id'] or not state['name']:
                continue

            normalized.append(
                {
                    'id': state['id'],
                    'type': state['type'] or 'function',
                    'function': {
                        'name': state['name'],
                        'arguments': arguments,
                    },
                }
            )
        return normalized or None

    def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
        """Apply common requester config to args dict."""
        if self.requester_cfg.get('base_url'):
            args['api_base'] = self.requester_cfg['base_url']
        if self.requester_cfg.get('timeout'):
            args['timeout'] = self.requester_cfg['timeout']
        if include_retry_params:
            if self.requester_cfg.get('drop_params'):
                args['drop_params'] = self.requester_cfg['drop_params']
            if self.requester_cfg.get('num_retries'):
                args['num_retries'] = self.requester_cfg['num_retries']
        if self.requester_cfg.get('api_version'):
            args['api_version'] = self.requester_cfg['api_version']
        return args

    def _handle_litellm_error(self, e: Exception) -> typing.NoReturn:
        """Convert LiteLLM exceptions to RequesterError. Never returns, always raises."""
        # Check more specific exceptions first (they inherit from base exceptions)
        if isinstance(e, litellm.ContextWindowExceededError):
            raise errors.RequesterError(f'Context length exceeded: {str(e)}')
        if isinstance(e, litellm.BadRequestError):
            raise errors.RequesterError(f'Invalid request parameters: {str(e)}')
        if isinstance(e, litellm.AuthenticationError):
            raise errors.RequesterError(f'Invalid API key: {str(e)}')
        if isinstance(e, litellm.NotFoundError):
            raise errors.RequesterError(f'Invalid model or path: {str(e)}')
        if isinstance(e, litellm.RateLimitError):
            raise errors.RequesterError(f'Rate limited or insufficient balance: {str(e)}')
        if isinstance(e, litellm.Timeout):
            raise errors.RequesterError(f'Request timed out: {str(e)}')
        if isinstance(e, litellm.APIConnectionError):
            raise errors.RequesterError(f'Connection error: {str(e)}')
        if isinstance(e, litellm.APIError):
            raise errors.RequesterError(f'API error: {str(e)}')
        raise errors.RequesterError(f'Unknown error: {str(e)}')

    async def _build_completion_args(
        self,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        stream: bool = False,
    ) -> dict:
        """Build common completion arguments for invoke_llm and invoke_llm_stream."""
        req_messages = self._convert_messages(messages)
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()

        args = {
            'model': model_name,
            'messages': req_messages,
            'api_key': api_key,
        }
        if stream:
            args['stream'] = True
            args['stream_options'] = {'include_usage': True}

        self._build_common_args(args)

        # Apply model-level extra_args first, then call-level extra_args
        if model.model_entity.extra_args:
            args.update(model.model_entity.extra_args)
        args.update(extra_args)

        if funcs:
            tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
            if tools:
                args['tools'] = tools
                args.setdefault('tool_choice', 'auto')
        return args

    async def invoke_llm(
        self,
        query: pipeline_query.Query,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        remove_think: bool = False,
    ) -> tuple[provider_message.Message, dict]:
        """Invoke LLM and return message with usage info."""
        args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
        try:
            response = await acompletion(**args)
            message_data = response.choices[0].message.model_dump()
            if 'role' not in message_data or message_data['role'] is None:
                message_data['role'] = 'assistant'

            content = message_data.get('content', '')
            reasoning_content = message_data.get('reasoning_content', None)
            message_data['content'] = self._process_thinking_content(content, reasoning_content, remove_think)
            if 'reasoning_content' in message_data:
                del message_data['reasoning_content']

            message = provider_message.Message(**message_data)
            usage_info = self._extract_usage(response)
            return message, usage_info
        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_llm_stream(
        self,
        query: pipeline_query.Query,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        remove_think: bool = False,
    ) -> provider_message.MessageChunk:
        """Invoke LLM streaming and yield chunks."""
        args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
        chunk_idx = 0
        role = 'assistant'
        tool_call_state: dict[int, dict[str, str]] = {}
        try:
            response = await acompletion(**args)
            async for chunk in response:
                # Capture usage whenever a chunk carries it.
                #
                # Important: many OpenAI-compatible gateways (e.g. new-api) and
                # providers send the final usage payload in a chunk that STILL
                # contains an (empty-delta) choice, not an empty `choices` list.
                # The previous implementation only captured usage when `choices`
                # was empty, so streamed calls always recorded 0 tokens.
                # We therefore capture usage independently of `choices`, and then
                # fall through to also process any content this chunk may carry.
                if getattr(chunk, 'usage', None):
                    usage_info = self._normalize_usage(chunk.usage)
                    if query is not None:
                        if query.variables is None:
                            query.variables = {}
                        query.variables['_stream_usage'] = usage_info

                if not hasattr(chunk, 'choices') or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
                finish_reason = getattr(choice, 'finish_reason', None)

                if 'role' in delta and delta['role']:
                    role = delta['role']

                delta_content = delta.get('content', '')
                reasoning_content = delta.get('reasoning_content', '')

                # Handle reasoning_content based on remove_think flag
                if reasoning_content:
                    if remove_think:
                        # Skip reasoning content when remove_think is True
                        chunk_idx += 1
                        continue
                    else:
                        # Use reasoning_content as the displayed content
                        delta_content = reasoning_content

                tool_calls = self._normalize_stream_tool_calls(delta.get('tool_calls'), tool_call_state)

                if chunk_idx == 0 and not delta_content and not tool_calls:
                    chunk_idx += 1
                    continue

                chunk_data = {
                    'role': role,
                    'content': delta_content if delta_content else None,
                    'tool_calls': tool_calls,
                    'is_final': bool(finish_reason),
                }
                chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
                yield provider_message.MessageChunk(**chunk_data)
                chunk_idx += 1
        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_embedding(
        self,
        model: requester.RuntimeEmbeddingModel,
        input_text: list[str],
        extra_args: dict[str, typing.Any] = {},
    ) -> tuple[list[list[float]], dict]:
        """Invoke embedding and return vectors with usage info."""
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()

        args = {
            'model': model_name,
            'input': input_text,
            'api_key': api_key,
        }
        self._build_common_args(args, include_retry_params=False)
        if model.model_entity.extra_args:
            args.update(model.model_entity.extra_args)
        args.update(extra_args)

        try:
            response = await aembedding(**args)
            # LiteLLM returns response.data entries either as objects with an
            # `.embedding` attribute or as plain dicts (many OpenAI-compatible
            # gateways, e.g. new-api, yield dict-shaped entries). Handle both.
            embeddings = [d['embedding'] if isinstance(d, dict) else d.embedding for d in response.data]
            usage_info = self._extract_usage(response)
            return embeddings, usage_info
        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_rerank(
        self,
        model: requester.RuntimeRerankModel,
        query: str,
        documents: typing.List[str],
        extra_args: dict[str, typing.Any] = {},
    ) -> typing.List[dict]:
        """Invoke rerank and return relevance scores."""
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()
        top_n = min(len(documents), 64)
        provider = self._get_custom_llm_provider()

        try:
            # LiteLLM's rerank API does not support the `openai` provider
            # (litellm/rerank_api/main.py raises "Unsupported provider: openai").
            # OpenAI-compatible gateways (newapi / one-api / vLLM / Xinference, etc.)
            # expose the standard Jina/Cohere-style POST /v1/rerank endpoint, so
            # call it directly over HTTP for openai-compatible (or unspecified) providers.
            if provider in (None, '', 'openai'):
                results = await self._invoke_rerank_openai_compatible(
                    model_name=model.model_entity.name,
                    query=query,
                    documents=documents,
                    api_key=api_key,
                    top_n=top_n,
                    extra_args={**(model.model_entity.extra_args or {}), **extra_args},
                )
            else:
                args = {
                    'model': model_name,
                    'query': query,
                    'documents': documents,
                    'api_key': api_key,
                    'top_n': top_n,
                }
                self._build_common_args(args, include_retry_params=False)
                if model.model_entity.extra_args:
                    args.update(model.model_entity.extra_args)
                args.update(extra_args)

                response = await arerank(**args)
                results = []
                for r in response.results:
                    results.append(
                        {
                            'index': r.get('index', 0),
                            'relevance_score': r.get('relevance_score', 0.0),
                        }
                    )

            # Min-max normalize scores to [0, 1] when they are not all equal.
            if results:
                scores = [r['relevance_score'] for r in results]
                min_score = min(scores)
                max_score = max(scores)
                if max_score - min_score > 1e-6:
                    for r in results:
                        r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)
            return results
        except errors.RequesterError:
            raise
        except Exception as e:
            self._handle_litellm_error(e)

    async def _invoke_rerank_openai_compatible(
        self,
        model_name: str,
        query: str,
        documents: typing.List[str],
        api_key: str,
        top_n: int,
        extra_args: dict[str, typing.Any] = {},
    ) -> typing.List[dict]:
        """Call the standard Jina/Cohere-style POST /v1/rerank endpoint over HTTP.

        Used for OpenAI-compatible gateways where litellm.arerank rejects the
        `openai` provider. Returns the same shape as the litellm path:
        a list of {'index': int, 'relevance_score': float}.
        """
        import httpx

        base_url = (self.requester_cfg.get('base_url') or '').rstrip('/')
        if not base_url:
            raise errors.RequesterError('Base URL required for rerank')
        timeout = self.requester_cfg.get('timeout', 120)

        headers = {'Content-Type': 'application/json'}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        payload: dict[str, typing.Any] = {
            'model': model_name,
            'query': query,
            'documents': documents,
            'top_n': top_n,
        }
        if extra_args:
            payload.update(extra_args)

        rerank_url = f'{base_url}/rerank'
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                resp = await client.post(rerank_url, headers=headers, json=payload)
                resp.raise_for_status()
                data = resp.json()
        except httpx.HTTPStatusError as e:
            body = ''
            try:
                body = e.response.text
            except Exception:
                pass
            raise errors.RequesterError(f'rerank request failed (HTTP {e.response.status_code}): {body or str(e)}')
        except httpx.HTTPError as e:
            raise errors.RequesterError(f'rerank connection error: {str(e)}')

        raw_results = data.get('results', []) if isinstance(data, dict) else []
        results = []
        for r in raw_results:
            results.append(
                {
                    'index': r.get('index', 0),
                    # Accept either 'relevance_score' or 'score' from the gateway.
                    'relevance_score': r.get('relevance_score', r.get('score', 0.0)) or 0.0,
                }
            )
        return results

    async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
        """Scan models supported by the provider."""
        import httpx

        base_url = self.requester_cfg.get('base_url', '').rstrip('/')
        timeout = self.requester_cfg.get('timeout', 120)
        if not base_url:
            raise errors.RequesterError('Base URL required for model scanning')

        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        models_url = f'{base_url}/models'
        try:
            async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
                response = await client.get(models_url, headers=headers)
                response.raise_for_status()
                payload = response.json()

            models = []
            for item in payload.get('data', []):
                model_id = item.get('id')
                if not model_id:
                    continue
                models.append(self._enrich_scanned_model(model_id, item))
            # LLMs first, then alphabetical by name.
            models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))
            return {'models': models}
        except httpx.HTTPStatusError as e:
            raise errors.RequesterError(f'Model scan failed: {e.response.status_code}')
        except httpx.TimeoutException:
            raise errors.RequesterError('Model scan timeout')
        except Exception as e:
            raise errors.RequesterError(f'Model scan error: {str(e)}')