mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 17:56:03 +00:00
* refactor(provider): use LiteLLM as unified LLM requester backend
- Replace 23+ individual requester implementations with unified litellmchat.py
- Add litellm_provider field to 27 YAML manifests for provider routing
- Delete redundant requester subclasses
- Add unit tests for LiteLLMRequester (29 tests)
- Fix num_retries parameter name (was max_retries)
- Fix exception handling order for subclass exceptions
LiteLLM provides a unified API for 100+ providers, eliminating the need for
provider-specific requesters.
* fix: ruff format provider.py
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(provider): simplify LiteLLM requester usage handling
- Remove unused Anthropic-specific tool schema generation
- Share completion argument construction between normal and streaming calls
- Use LiteLLM/OpenAI native usage fields for monitoring
- Collect stream token usage from LiteLLM stream_options
- Update LiteLLM requester tests for unified usage fields
* restore: restore deleted provider requester files
Restore individual provider requester implementations that were
removed in de61b5d3. These files coexist with the unified
litellmchat.py backend.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* feat: update requesters and improve provider selection UI
- Added `litellm_provider` field to various requesters' YAML configurations.
- Removed obsolete Python requester files for OpenRouter, PPIO, QHAIGC, ShengSuanYun, SiliconFlow, Space, TokenPony, VolcArk, and Xai.
- Introduced new requesters for Tencent and Together AI with corresponding YAML configurations and SVG icons.
- Enhanced the ProviderForm component to include a searchable dropdown for selecting providers, improving user experience.
- Updated localization files to include search provider text for both English and Chinese.
* fix(provider): align litellm rebase with master
* fix(provider): capture streaming token usage; add token observability
The LiteLLM streaming requester only captured usage when a chunk had an
empty `choices` list. Many OpenAI-compatible gateways (e.g. new-api) and
providers send the final usage payload in a chunk that still carries an
empty-delta choice, so streamed calls always recorded 0 tokens in the
monitoring logs/dashboard (non-streaming worked).
- Capture stream usage whenever a chunk carries it, regardless of choices
- Add robust _normalize_usage (dict/obj shapes, derive missing total_tokens)
- Register litellm in bootutils/deps.py (was in pyproject only)
- Add MonitoringService.get_token_statistics + /monitoring/token-statistics
endpoint: summary, per-model breakdown, token timeseries, and a
zero-token-success data-quality signal
- Add TokenMonitoring dashboard tab (summary tiles, stacked token chart,
per-model table) + i18n (en/zh)
- Regression tests for stream usage capture and usage normalization
Verified end-to-end against a real OpenAI-compatible endpoint with
gpt-5.5 and claude-opus-4-8: tokens now recorded non-zero for both
streaming and non-streaming paths.
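The capture rule described above can be sketched roughly as follows. This is a standalone illustration, not the requester code itself: `normalize_usage`, `collect_stream_usage`, and the `SimpleNamespace` chunks are hypothetical stand-ins for the real stream objects.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def normalize_usage(usage: Any) -> dict:
    """Coerce a dict- or object-shaped usage payload into plain integer counts."""
    if usage is None:
        return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

    def get(key: str) -> int:
        value = usage.get(key) if isinstance(usage, dict) else getattr(usage, key, None)
        return int(value or 0)

    prompt, completion = get('prompt_tokens'), get('completion_tokens')
    # Derive total_tokens when the upstream omits it.
    total = get('total_tokens') or prompt + completion
    return {'prompt_tokens': prompt, 'completion_tokens': completion, 'total_tokens': total}


def collect_stream_usage(chunks) -> dict | None:
    """Return the last usage payload seen, whether or not the chunk has choices."""
    captured = None
    for chunk in chunks:
        usage = getattr(chunk, 'usage', None)
        if usage:  # deliberately not gated on `chunk.choices` being empty
            captured = normalize_usage(usage)
    return captured


# Hypothetical stream: the final chunk carries usage AND an empty-delta choice.
chunks = [
    SimpleNamespace(choices=[SimpleNamespace(delta={'content': 'hi'})], usage=None),
    SimpleNamespace(choices=[SimpleNamespace(delta={})], usage={'prompt_tokens': 12, 'completion_tokens': 3}),
]
print(collect_stream_usage(chunks))
```

Gating the capture on an empty `choices` list would skip the second chunk here, which is the zero-token symptom the fix addresses.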
* refactor(provider): simplify litellm capabilities
* style: simplify wrapped expressions
* feat(models): persist context metadata
* fix(provider): handle dict embeddings and openai-compatible rerank in LiteLLMRequester
- invoke_embedding: support both object- and dict-shaped response.data
entries (OpenAI-compatible gateways like new-api return dicts)
- invoke_rerank: litellm.arerank rejects the 'openai' provider, so for
openai-compatible (or unspecified) providers call the standard
Jina/Cohere-style POST /v1/rerank endpoint directly over HTTP
- accept both 'relevance_score' and 'score' fields in rerank results
- add unit tests for the openai-compatible HTTP rerank path
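A minimal sketch of the result handling on the HTTP rerank path, assuming the gateway returns a JSON object with a `results` list. The function name and sample payloads are hypothetical; only the `relevance_score` / `score` fallback comes from the change above.

```python
from __future__ import annotations

from typing import Any


def normalize_rerank_results(data: Any) -> list[dict]:
    """Map a Jina/Cohere-style rerank response to index/relevance_score pairs."""
    raw_results = data.get('results', []) if isinstance(data, dict) else []
    return [
        {
            'index': r.get('index', 0),
            # Prefer 'relevance_score', fall back to 'score', and never emit None.
            'relevance_score': r.get('relevance_score', r.get('score', 0.0)) or 0.0,
        }
        for r in raw_results
    ]


# Hypothetical payloads from two gateways that name the score field differently.
cohere_style = {'results': [{'index': 1, 'relevance_score': 0.9}]}
score_style = {'results': [{'index': 0, 'score': 0.4}, {'index': 2, 'score': None}]}
print(normalize_rerank_results(cohere_style))
print(normalize_rerank_results(score_style))
```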
* feat(provider): enforce requester support_type when adding models
- frontend: AddModelPopover only shows model-type tabs (llm/embedding/
rerank) that the provider's requester declares in its manifest
support_type; ModelsDialog fetches requester manifests and maps
requester -> support_type, passed down through ProviderCard
- backend: add _validate_provider_supports guard in create_llm_model /
create_embedding_model / create_rerank_model so a model cannot be
attached to a provider whose requester does not support that type,
even if the frontend restriction is bypassed (manifests without
support_type are allowed for backward compatibility)
- manifests: correct support_type for providers that do not offer all
three model types:
- llm only: anthropic, deepseek, groq, moonshot, openrouter, xai
- llm + text-embedding: openai, gemini, mistral
- add rerank to new-api (verified working via /v1/rerank)
- set llm + text-embedding + rerank for aggregator/unknown gateways
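The backend guard can be pictured along these lines. This is a sketch under assumptions: the manifest is modelled as a plain dict with an optional `support_type` list, and `UnsupportedModelTypeError` is a hypothetical exception; the real `_validate_provider_supports` may look the manifest up differently.

```python
from __future__ import annotations


class UnsupportedModelTypeError(ValueError):
    """Hypothetical error for a model type the requester does not declare."""


def validate_provider_supports(manifest: dict | None, model_type: str) -> None:
    """Reject model_type unless the manifest's support_type lists it."""
    support_type = (manifest or {}).get('support_type')
    if not support_type:
        # Manifests without support_type are allowed for backward compatibility.
        return
    if model_type not in support_type:
        raise UnsupportedModelTypeError(f'requester does not support model type: {model_type}')


# An llm-only manifest accepts an LLM model but rejects a rerank model.
llm_only = {'support_type': ['llm']}
validate_provider_supports(llm_only, 'llm')
try:
    validate_provider_supports(llm_only, 'rerank')
except UnsupportedModelTypeError as exc:
    print(exc)
```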
* feat(provider): add searchable alias to requester manifests
- add a free-text 'alias' field to every requester manifest spec,
containing the vendor's English/Chinese names, pinyin, common
nicknames and flagship model-series names (e.g. moonshot -> kimi,
月之暗面; zhipu -> glm, 智谱清言)
- frontend: ProviderForm requester search now also matches against
alias (substring/contains), so searching 'kimi' surfaces Moonshot,
'硅基' surfaces SiliconFlow, etc.
- also fix support_type: openrouter (relay) supports embedding+rerank;
LangBot Space gains rerank (coming soon)
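The matching rule is simple enough to sketch. The actual search lives in the frontend ProviderForm; the Python below only illustrates the substring logic, and the `label` / `alias` values are hypothetical examples rather than the real manifest contents.

```python
from __future__ import annotations


def matches_requester(query: str, label: str, alias: str = '') -> bool:
    """Case-insensitive substring match against the label or the free-text alias."""
    needle = query.strip().lower()
    if not needle:
        return True  # an empty search shows everything
    return needle in label.lower() or needle in alias.lower()


# Hypothetical manifest excerpts.
requesters = [
    {'label': 'Moonshot', 'alias': 'moonshot kimi 月之暗面'},
    {'label': 'SiliconFlow', 'alias': 'siliconflow 硅基流动'},
]


def search(query: str) -> list[str]:
    return [r['label'] for r in requesters if matches_requester(query, r['label'], r['alias'])]


print(search('kimi'))
print(search('硅基'))
```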
* fix(provider): make support_type guard defensive against incomplete model_mgr
- _validate_provider_supports now uses getattr to gracefully skip when
model_mgr / provider_dict / manifest lookup is unavailable, instead of
raising AttributeError (fixes unit tests that mock ap.model_mgr as a
bare SimpleNamespace)
- add TestValidateProviderSupports covering: allow supported type,
reject unsupported type, allow when support_type missing, allow when
provider unknown, degrade safely when model_mgr is incomplete
* fix(persistence): guard 0004 migration against missing llm_models table
The 0004_add_llm_model_context_length migration called
inspector.get_columns('llm_models') unconditionally, raising
NoSuchTableError when the table does not exist (e.g. migrating a
fresh/empty DB, as exercised by the integration tests where
create_all() registers no tables because the ORM models are not
imported). Every other migration guards with a table-existence check
first; add the same guard here for both upgrade and downgrade.
Also restore the test head assertion to 0004 (it had been lowered to
0003 to mask this failure).
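The guard pattern, shown here with the standard-library `sqlite3` module rather than Alembic and SQLAlchemy so the sketch stays self-contained. The helper names and the `context_length` column name are assumptions made for illustration.

```python
import sqlite3


def table_exists(conn: sqlite3.Connection, table: str) -> bool:
    """Check sqlite_master for the table before touching its columns."""
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    return row is not None


def add_context_length_column(conn: sqlite3.Connection) -> bool:
    """Add the column only if the table exists and lacks it. Returns True if added."""
    if not table_exists(conn, 'llm_models'):
        return False  # fresh/empty DB: nothing to migrate
    columns = {row[1] for row in conn.execute("PRAGMA table_info('llm_models')")}
    if 'context_length' in columns:
        return False
    conn.execute('ALTER TABLE llm_models ADD COLUMN context_length INTEGER')
    return True


conn = sqlite3.connect(':memory:')
skipped_on_empty_db = add_context_length_column(conn)  # table missing, no error raised
conn.execute('CREATE TABLE llm_models (uuid TEXT PRIMARY KEY)')
added = add_context_length_column(conn)
added_again = add_context_length_column(conn)  # idempotent on re-run
```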
* Merge branch 'master' into feat/litellm
Resolve conflicts:
- uv.lock: regenerated via 'uv lock' to reconcile litellm/fastuuid
(ours) with openai bump (master).
- Alembic migrations: master added 0004_add_mcp_readme while this
branch added 0004_add_llm_model_context_length, both as children of
0003 (would create multiple heads). Re-chain the litellm migration as
0005_add_llm_model_context_length with down_revision=0004_add_mcp_readme
for a single linear head. Update test head assertion accordingly.
* fix(persistence): shorten migration revision id to fit varchar(32)
PostgreSQL stores alembic_version.version_num as varchar(32).
'0005_add_llm_model_context_length' (33 chars) overflowed it, raising
StringDataRightTruncationError in the PG migration tests. Rename the
revision (and file) to '0005_add_llm_context_length' (27 chars) and
update the head assertions in both SQLite and PostgreSQL migration
tests.
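A length check of this kind could catch the overflow before it reaches PostgreSQL. The helper name is hypothetical; the limit and the two revision ids are the ones quoted above.

```python
ALEMBIC_VERSION_NUM_MAX = 32  # alembic_version.version_num is varchar(32)


def fits_version_num(revision_id: str, limit: int = ALEMBIC_VERSION_NUM_MAX) -> bool:
    """Return True when the revision id fits in the version_num column."""
    return len(revision_id) <= limit


old_id = '0005_add_llm_model_context_length'
new_id = '0005_add_llm_context_length'
print(len(old_id), fits_version_num(old_id))
print(len(new_id), fits_version_num(new_id))
```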
---------
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: fdc310 <2213070223@qq.com>
Co-authored-by: RockChinQ <rockchinq@gmail.com>
734 lines
29 KiB
Python
"""LiteLLM unified requester for chat, embedding, and rerank."""

from __future__ import annotations

import typing

import litellm
from litellm import acompletion, aembedding, arerank

from .. import errors, requester
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message


class LiteLLMRequester(requester.ProviderAPIRequester):
    """LiteLLM unified API requester supporting chat, embedding, and rerank."""

    _EMBEDDING_MODEL_HINTS = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
    _RERANK_MODEL_HINTS = ('rerank', 're-rank', 're_rank')

    default_config: dict[str, typing.Any] = {
        'base_url': '',
        'timeout': 120,
        'custom_llm_provider': '',
        'drop_params': False,
        'num_retries': 0,
        'api_version': '',
    }

    async def initialize(self):
        """Initialize LiteLLM client settings."""
        # LiteLLM doesn't require explicit client initialization
        # Configuration is passed per-request via litellm params
        pass

    def _build_litellm_model_name(self, model_name: str, custom_llm_provider: str | None = None) -> str:
        """Build LiteLLM model name with provider prefix if needed."""
        provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '')
        if provider:
            # LiteLLM format: provider/model_name
            if model_name.startswith(f'{provider}/'):
                return model_name
            return f'{provider}/{model_name}'
        # If no custom provider, assume model_name already includes prefix or is OpenAI-compatible
        return model_name

    def _get_custom_llm_provider(self) -> str | None:
        return self.requester_cfg.get('custom_llm_provider') or None

    def _safe_litellm_bool_helper(self, helper_name: str, model_name: str) -> bool:
        """Call a LiteLLM boolean capability helper without letting metadata gaps fail requests."""
        helper = getattr(litellm, helper_name, None)
        if not callable(helper):
            return False

        provider = self._get_custom_llm_provider()
        candidates: list[tuple[str, str | None]] = [(model_name, provider)]
        litellm_model_name = self._build_litellm_model_name(model_name)
        if litellm_model_name != model_name:
            candidates.append((litellm_model_name, None))
        for metadata_provider in self._metadata_provider_candidates(model_name):
            candidates.append((f'{metadata_provider}/{model_name}', None))

        tried_candidates: set[tuple[str, str | None]] = set()
        for candidate_model, candidate_provider in candidates:
            candidate_key = (candidate_model, candidate_provider)
            if candidate_key in tried_candidates:
                continue
            tried_candidates.add(candidate_key)
            try:
                if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)):
                    return True
            except Exception:
                continue
        return False

    def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
        if not model_payload:
            return None

        for field_name in ('context_length', 'context_window', 'max_context_length'):
            value = model_payload.get(field_name)
            if isinstance(value, bool):
                continue
            if isinstance(value, int) and value > 0:
                return value
            if isinstance(value, str) and value.isdigit():
                parsed_value = int(value)
                if parsed_value > 0:
                    return parsed_value
        return None

    def _metadata_provider_candidates(self, model_name: str) -> list[str]:
        normalized_model_name = (model_name or '').lower()
        candidates = []
        if normalized_model_name.startswith(('moonshot-', 'kimi-')):
            candidates.append('moonshot')
        if normalized_model_name.startswith('deepseek-'):
            candidates.append('deepseek')

        base_url = self.requester_cfg.get('base_url', '').lower()
        if 'moonshot' in base_url:
            candidates.append('moonshot')
        if 'deepseek' in base_url:
            candidates.append('deepseek')

        deduped_candidates = []
        for candidate in candidates:
            if candidate not in deduped_candidates:
                deduped_candidates.append(candidate)
        return deduped_candidates

    def _known_context_length_fallback(self, model_name: str) -> int | None:
        normalized_model_name = (model_name or '').lower()
        if normalized_model_name.startswith('deepseek-v4-'):
            return 1_000_000
        if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')):
            return 256 * 1024
        if normalized_model_name.startswith('moonshot-v1-8k'):
            return 8 * 1024
        if normalized_model_name.startswith('moonshot-v1-32k'):
            return 32 * 1024
        if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto':
            return 128 * 1024
        return None

    def _safe_context_length(self, model_name: str) -> int | None:
        helper = getattr(litellm, 'get_max_tokens', None)
        if not callable(helper):
            return self._known_context_length_fallback(model_name)

        candidates = [model_name]
        litellm_model_name = self._build_litellm_model_name(model_name)
        if litellm_model_name != model_name:
            candidates.append(litellm_model_name)
        for provider in self._metadata_provider_candidates(model_name):
            candidates.append(f'{provider}/{model_name}')

        tried_candidates = []
        for candidate in candidates:
            if candidate in tried_candidates:
                continue
            tried_candidates.append(candidate)
            try:
                max_tokens = helper(candidate)
            except Exception:
                continue
            if isinstance(max_tokens, int) and max_tokens > 0:
                return max_tokens
        return self._known_context_length_fallback(model_name)

    def _supports_function_calling(self, model_name: str) -> bool:
        return self._safe_litellm_bool_helper('supports_function_calling', model_name)

    def _supports_vision(self, model_name: str) -> bool:
        return self._safe_litellm_bool_helper('supports_vision', model_name)

    def _infer_model_type(self, model_id: str) -> str:
        normalized_id = (model_id or '').lower()
        if any(kw in normalized_id for kw in self._RERANK_MODEL_HINTS):
            return 'rerank'
        if any(kw in normalized_id for kw in self._EMBEDDING_MODEL_HINTS):
            return 'embedding'
        return 'llm'

    def _enrich_scanned_model(
        self,
        model_id: str,
        model_payload: dict[str, typing.Any] | None = None,
    ) -> dict[str, typing.Any]:
        model_type = self._infer_model_type(model_id)
        scanned_model: dict[str, typing.Any] = {
            'id': model_id,
            'name': model_id,
            'type': model_type,
        }

        if model_type == 'llm':
            abilities = []
            if self._supports_function_calling(model_id):
                abilities.append('func_call')
            supports_provider_reported_vision = bool(
                model_payload
                and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True)
            )
            if supports_provider_reported_vision or self._supports_vision(model_id):
                abilities.append('vision')
            scanned_model['abilities'] = abilities

            context_length = self._context_length_from_scan_payload(model_payload)
            if context_length is None:
                context_length = self._safe_context_length(model_id)
            if context_length is not None:
                scanned_model['context_length'] = context_length

        return scanned_model

    def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]:
        """Convert LangBot messages to LiteLLM/OpenAI format."""
        req_messages = []
        for m in messages:
            msg_dict = m.dict(exclude_none=True)
            content = msg_dict.get('content')

            if isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get('type') == 'image_base64':
                        part['image_url'] = {'url': part['image_base64']}
                        part['type'] = 'image_url'
                        del part['image_base64']

            req_messages.append(msg_dict)

        return req_messages

    def _process_thinking_content(self, content: str, reasoning_content: str | None, remove_think: bool) -> str:
        """Process thinking/reasoning content.

        Args:
            content: The main content from response
            reasoning_content: Separate reasoning content from model
            remove_think: If True, remove thinking markers; if False, preserve them

        Returns:
            Processed content string
        """
        # Extract and handle thinking tags
        if content and 'CRETIRE_REASONING_BEGINk' in content and 'CRETIRE_REASONING_ENDk' in content:
            import re

            think_pattern = r'CRETIRE_REASONING_BEGINk(.*?)CRETIRE_REASONING_ENDk'

            if remove_think:
                # Remove thinking tags and their content from output
                content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
            # else: preserve thinking content as-is

        # Handle separate reasoning_content field
        # Currently we don't include reasoning_content in user-facing output regardless of remove_think
        # because it's typically internal model reasoning, not user-visible thinking
        return content or ''

    @staticmethod
    def _normalize_usage(usage: typing.Any) -> dict:
        """Normalize a LiteLLM/OpenAI usage object into a plain token dict.

        Handles several real-world shapes returned by different upstreams:
        - object with ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens`` attrs
        - dict with the same keys
        - missing ``total_tokens`` (derived from prompt + completion)
        - ``None`` / partially-populated usage (defaults to 0)
        """
        if usage is None:
            return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}

        def _get(key: str) -> typing.Any:
            if isinstance(usage, dict):
                return usage.get(key)
            return getattr(usage, key, None)

        prompt_tokens = _get('prompt_tokens') or 0
        completion_tokens = _get('completion_tokens') or 0
        total_tokens = _get('total_tokens') or 0

        # Some providers omit total_tokens in streaming usage; derive it.
        if not total_tokens:
            total_tokens = prompt_tokens + completion_tokens

        return {
            'prompt_tokens': int(prompt_tokens),
            'completion_tokens': int(completion_tokens),
            'total_tokens': int(total_tokens),
        }

    def _extract_usage(self, response) -> dict:
        """Extract usage info from a non-streaming LiteLLM response."""
        return self._normalize_usage(getattr(response, 'usage', None))

    @staticmethod
    def _as_dict(value: typing.Any) -> dict:
        if value is None:
            return {}
        if isinstance(value, dict):
            return value
        if hasattr(value, 'model_dump'):
            return value.model_dump()
        return {}

    def _normalize_stream_tool_calls(
        self,
        raw_tool_calls: typing.Any,
        tool_call_state: dict[int, dict[str, str]],
    ) -> list[dict] | None:
        """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them."""
        if not raw_tool_calls:
            return None

        normalized = []
        for fallback_index, raw_tool_call in enumerate(raw_tool_calls):
            tool_call = self._as_dict(raw_tool_call)
            index = tool_call.get('index')
            if not isinstance(index, int):
                index = fallback_index

            state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''})
            if tool_call.get('id'):
                state['id'] = tool_call['id']
            if tool_call.get('type'):
                state['type'] = tool_call['type']

            function = self._as_dict(tool_call.get('function'))
            if function.get('name'):
                state['name'] = function['name']

            arguments = function.get('arguments')
            if arguments is None:
                arguments = ''
            elif not isinstance(arguments, str):
                arguments = str(arguments)

            if not state['id'] or not state['name']:
                continue

            normalized.append(
                {
                    'id': state['id'],
                    'type': state['type'] or 'function',
                    'function': {
                        'name': state['name'],
                        'arguments': arguments,
                    },
                }
            )

        return normalized or None

    def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
        """Apply common requester config to args dict."""
        if self.requester_cfg.get('base_url'):
            args['api_base'] = self.requester_cfg['base_url']
        if self.requester_cfg.get('timeout'):
            args['timeout'] = self.requester_cfg['timeout']
        if include_retry_params:
            if self.requester_cfg.get('drop_params'):
                args['drop_params'] = self.requester_cfg['drop_params']
            if self.requester_cfg.get('num_retries'):
                args['num_retries'] = self.requester_cfg['num_retries']
        if self.requester_cfg.get('api_version'):
            args['api_version'] = self.requester_cfg['api_version']
        return args

    def _handle_litellm_error(self, e: Exception) -> None:
        """Convert LiteLLM exceptions to RequesterError. Never returns, always raises."""
        # Check more specific exceptions first (they inherit from base exceptions)
        if isinstance(e, litellm.ContextWindowExceededError):
            raise errors.RequesterError(f'上下文长度超限: {str(e)}')
        if isinstance(e, litellm.BadRequestError):
            raise errors.RequesterError(f'请求参数错误: {str(e)}')
        if isinstance(e, litellm.AuthenticationError):
            raise errors.RequesterError(f'API key 无效: {str(e)}')
        if isinstance(e, litellm.NotFoundError):
            raise errors.RequesterError(f'模型或路径无效: {str(e)}')
        if isinstance(e, litellm.RateLimitError):
            raise errors.RequesterError(f'请求过于频繁或余额不足: {str(e)}')
        if isinstance(e, litellm.Timeout):
            raise errors.RequesterError(f'请求超时: {str(e)}')
        if isinstance(e, litellm.APIConnectionError):
            raise errors.RequesterError(f'连接错误: {str(e)}')
        if isinstance(e, litellm.APIError):
            raise errors.RequesterError(f'API 错误: {str(e)}')
        raise errors.RequesterError(f'未知错误: {str(e)}')

    async def _build_completion_args(
        self,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        stream: bool = False,
    ) -> dict:
        """Build common completion arguments for invoke_llm and invoke_llm_stream."""
        req_messages = self._convert_messages(messages)
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()

        args = {
            'model': model_name,
            'messages': req_messages,
            'api_key': api_key,
        }
        if stream:
            args['stream'] = True
            args['stream_options'] = {'include_usage': True}
        self._build_common_args(args)

        # Apply model-level extra_args first, then call-level extra_args
        if model.model_entity.extra_args:
            args.update(model.model_entity.extra_args)
        args.update(extra_args)

        if funcs:
            tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
            if tools:
                args['tools'] = tools
                args.setdefault('tool_choice', 'auto')

        return args

    async def invoke_llm(
        self,
        query: pipeline_query.Query,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        remove_think: bool = False,
    ) -> tuple[provider_message.Message, dict]:
        """Invoke LLM and return message with usage info."""
        args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)

        try:
            response = await acompletion(**args)

            message_data = response.choices[0].message.model_dump()
            if 'role' not in message_data or message_data['role'] is None:
                message_data['role'] = 'assistant'

            content = message_data.get('content', '')
            reasoning_content = message_data.get('reasoning_content', None)
            message_data['content'] = self._process_thinking_content(content, reasoning_content, remove_think)

            if 'reasoning_content' in message_data:
                del message_data['reasoning_content']

            message = provider_message.Message(**message_data)
            usage_info = self._extract_usage(response)

            return message, usage_info

        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_llm_stream(
        self,
        query: pipeline_query.Query,
        model: requester.RuntimeLLMModel,
        messages: typing.List[provider_message.Message],
        funcs: typing.List[resource_tool.LLMTool] = None,
        extra_args: dict[str, typing.Any] = {},
        remove_think: bool = False,
    ) -> provider_message.MessageChunk:
        """Invoke LLM streaming and yield chunks."""
        args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)

        chunk_idx = 0
        role = 'assistant'
        tool_call_state: dict[int, dict[str, str]] = {}

        try:
            response = await acompletion(**args)
            async for chunk in response:
                # Capture usage whenever a chunk carries it.
                #
                # Important: many OpenAI-compatible gateways (e.g. new-api) and
                # providers send the final usage payload in a chunk that STILL
                # contains a (empty-delta) choice, not an empty `choices` list.
                # The previous implementation only captured usage when `choices`
                # was empty, so streamed calls always recorded 0 tokens.
                # We therefore capture usage independently of `choices`, and then
                # fall through to also process any content this chunk may carry.
                if getattr(chunk, 'usage', None):
                    usage_info = self._normalize_usage(chunk.usage)
                    if query is not None:
                        if query.variables is None:
                            query.variables = {}
                        query.variables['_stream_usage'] = usage_info

                if not hasattr(chunk, 'choices') or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
                finish_reason = getattr(choice, 'finish_reason', None)

                if 'role' in delta and delta['role']:
                    role = delta['role']

                delta_content = delta.get('content', '')
                reasoning_content = delta.get('reasoning_content', '')

                # Handle reasoning_content based on remove_think flag
                if reasoning_content:
                    if remove_think:
                        # Skip reasoning content when remove_think is True
                        chunk_idx += 1
                        continue
                    else:
                        # Use reasoning_content as the displayed content
                        delta_content = reasoning_content

                tool_calls = self._normalize_stream_tool_calls(delta.get('tool_calls'), tool_call_state)

                if chunk_idx == 0 and not delta_content and not tool_calls:
                    chunk_idx += 1
                    continue

                chunk_data = {
                    'role': role,
                    'content': delta_content if delta_content else None,
                    'tool_calls': tool_calls,
                    'is_final': bool(finish_reason),
                }

                chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
                yield provider_message.MessageChunk(**chunk_data)
                chunk_idx += 1

        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_embedding(
        self,
        model: requester.RuntimeEmbeddingModel,
        input_text: list[str],
        extra_args: dict[str, typing.Any] = {},
    ) -> tuple[list[list[float]], dict]:
        """Invoke embedding and return vectors with usage info."""
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()

        args = {
            'model': model_name,
            'input': input_text,
            'api_key': api_key,
        }
        self._build_common_args(args, include_retry_params=False)

        if model.model_entity.extra_args:
            args.update(model.model_entity.extra_args)

        args.update(extra_args)

        try:
            response = await aembedding(**args)

            # LiteLLM returns response.data entries either as objects with an
            # `.embedding` attribute or as plain dicts (many OpenAI-compatible
            # gateways, e.g. new-api, yield dict-shaped entries). Handle both.
            embeddings = [d['embedding'] if isinstance(d, dict) else d.embedding for d in response.data]
            usage_info = self._extract_usage(response)

            return embeddings, usage_info

        except Exception as e:
            self._handle_litellm_error(e)

    async def invoke_rerank(
        self,
        model: requester.RuntimeRerankModel,
        query: str,
        documents: typing.List[str],
        extra_args: dict[str, typing.Any] = {},
    ) -> typing.List[dict]:
        """Invoke rerank and return relevance scores."""
        model_name = self._build_litellm_model_name(model.model_entity.name)
        api_key = model.provider.token_mgr.get_token()

        top_n = min(len(documents), 64)

        provider = self._get_custom_llm_provider()

        try:
            # LiteLLM's rerank API does not support the `openai` provider
            # (litellm/rerank_api/main.py raises "Unsupported provider: openai").
            # OpenAI-compatible gateways (newapi / one-api / vLLM / Xinference, etc.)
            # expose the standard Jina/Cohere-style POST /v1/rerank endpoint, so
            # call it directly over HTTP for openai-compatible (or unspecified) providers.
            if provider in (None, '', 'openai'):
                results = await self._invoke_rerank_openai_compatible(
                    model_name=model.model_entity.name,
                    query=query,
                    documents=documents,
                    api_key=api_key,
                    top_n=top_n,
                    extra_args={**(model.model_entity.extra_args or {}), **extra_args},
                )
            else:
                args = {
                    'model': model_name,
                    'query': query,
                    'documents': documents,
                    'api_key': api_key,
                    'top_n': top_n,
                }
                self._build_common_args(args, include_retry_params=False)

                if model.model_entity.extra_args:
                    args.update(model.model_entity.extra_args)

                args.update(extra_args)

                response = await arerank(**args)

                results = []
                for r in response.results:
                    results.append(
                        {
                            'index': r.get('index', 0),
                            'relevance_score': r.get('relevance_score', 0.0),
                        }
                    )

            if results:
                scores = [r['relevance_score'] for r in results]
                min_score = min(scores)
                max_score = max(scores)
                if max_score - min_score > 1e-6:
                    for r in results:
                        r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)

            return results

        except errors.RequesterError:
            raise
        except Exception as e:
            self._handle_litellm_error(e)

    async def _invoke_rerank_openai_compatible(
        self,
        model_name: str,
        query: str,
        documents: typing.List[str],
        api_key: str,
        top_n: int,
        extra_args: dict[str, typing.Any] = {},
    ) -> typing.List[dict]:
        """Call the standard Jina/Cohere-style POST /v1/rerank endpoint over HTTP.

        Used for OpenAI-compatible gateways where litellm.arerank rejects the
        `openai` provider. Returns the same shape as the litellm path:
        a list of {'index': int, 'relevance_score': float}.
        """
        import httpx

        base_url = (self.requester_cfg.get('base_url') or '').rstrip('/')
        if not base_url:
            raise errors.RequesterError('Base URL required for rerank')

        timeout = self.requester_cfg.get('timeout', 120)

        headers = {'Content-Type': 'application/json'}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        payload: dict[str, typing.Any] = {
            'model': model_name,
            'query': query,
            'documents': documents,
            'top_n': top_n,
        }
        if extra_args:
            payload.update(extra_args)

        rerank_url = f'{base_url}/rerank'

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                resp = await client.post(rerank_url, headers=headers, json=payload)
                resp.raise_for_status()
                data = resp.json()
        except httpx.HTTPStatusError as e:
            body = ''
            try:
                body = e.response.text
            except Exception:
                pass
            raise errors.RequesterError(f'rerank 请求失败 (HTTP {e.response.status_code}): {body or str(e)}')
        except httpx.HTTPError as e:
            raise errors.RequesterError(f'rerank 连接错误: {str(e)}')

        raw_results = data.get('results', []) if isinstance(data, dict) else []
        results = []
        for r in raw_results:
            results.append(
                {
                    'index': r.get('index', 0),
                    'relevance_score': r.get('relevance_score', r.get('score', 0.0)) or 0.0,
                }
            )

        return results

    async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
        """Scan models supported by the provider."""
        import httpx

        base_url = self.requester_cfg.get('base_url', '').rstrip('/')
        timeout = self.requester_cfg.get('timeout', 120)

        if not base_url:
            raise errors.RequesterError('Base URL required for model scanning')

        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        models_url = f'{base_url}/models'

        try:
            async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
                response = await client.get(models_url, headers=headers)
                response.raise_for_status()
                payload = response.json()

            models = []
            for item in payload.get('data', []):
                model_id = item.get('id')
                if not model_id:
                    continue

                models.append(self._enrich_scanned_model(model_id, item))

            models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))

            return {'models': models}

        except httpx.HTTPStatusError as e:
            raise errors.RequesterError(f'Model scan failed: {e.response.status_code}')
        except httpx.TimeoutException:
            raise errors.RequesterError('Model scan timeout')
        except Exception as e:
            raise errors.RequesterError(f'Model scan error: {str(e)}')