mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
Compare commits
4 Commits
feat/agent
...
feat/litel
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8dd16aac51 | ||
|
|
d170bdd343 | ||
|
|
b33d05f99a | ||
|
|
de61b5d368 |
@@ -77,6 +77,7 @@ dependencies = [
|
|||||||
"pymilvus>=2.6.4",
|
"pymilvus>=2.6.4",
|
||||||
"pgvector>=0.4.1",
|
"pgvector>=0.4.1",
|
||||||
"botocore>=1.42.39",
|
"botocore>=1.42.39",
|
||||||
|
"litellm>=1.0.0",
|
||||||
]
|
]
|
||||||
keywords = [
|
keywords = [
|
||||||
"bot",
|
"bot",
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import sqlalchemy
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from . import requester
|
from . import requester
|
||||||
|
from .requesters import litellmchat
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...discover import engine
|
from ...discover import engine
|
||||||
from . import token
|
from . import token
|
||||||
@@ -42,6 +43,13 @@ class ModelManager:
|
|||||||
|
|
||||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
||||||
for component in self.requester_components:
|
for component in self.requester_components:
|
||||||
|
# Skip components that use litellm_provider (they will use litellmchat.py instead)
|
||||||
|
if component.spec.get('litellm_provider'):
|
||||||
|
self.ap.logger.debug(
|
||||||
|
f'Skipping Python class loading for {component.metadata.name} '
|
||||||
|
f'(uses litellm_provider={component.spec.get("litellm_provider")})'
|
||||||
|
)
|
||||||
|
continue
|
||||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||||
|
|
||||||
self.requester_dict = requester_dict
|
self.requester_dict = requester_dict
|
||||||
@@ -260,13 +268,34 @@ class ModelManager:
|
|||||||
else:
|
else:
|
||||||
provider_entity = provider_info
|
provider_entity = provider_info
|
||||||
|
|
||||||
if provider_entity.requester not in self.requester_dict:
|
# Get requester manifest to check for litellm_provider
|
||||||
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester)
|
||||||
|
|
||||||
|
# Build config from base_url
|
||||||
|
config = {'base_url': provider_entity.base_url}
|
||||||
|
|
||||||
|
# Check if requester manifest specifies litellm_provider
|
||||||
|
if requester_manifest and requester_manifest.spec.get('litellm_provider'):
|
||||||
|
# Use unified LiteLLMRequester with provider prefix
|
||||||
|
# Map litellm_provider (YAML spec) to custom_llm_provider (config)
|
||||||
|
config['custom_llm_provider'] = requester_manifest.spec['litellm_provider']
|
||||||
|
requester_inst = litellmchat.LiteLLMRequester(
|
||||||
|
ap=self.ap,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
self.ap.logger.debug(
|
||||||
|
f'Using LiteLLMRequester for {provider_entity.requester} '
|
||||||
|
f'with custom_llm_provider={config["custom_llm_provider"]}'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use original requester class (for backward compatibility)
|
||||||
|
if provider_entity.requester not in self.requester_dict:
|
||||||
|
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
||||||
|
requester_inst = self.requester_dict[provider_entity.requester](
|
||||||
|
ap=self.ap,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
requester_inst = self.requester_dict[provider_entity.requester](
|
|
||||||
ap=self.ap,
|
|
||||||
config={'base_url': provider_entity.base_url},
|
|
||||||
)
|
|
||||||
await requester_inst.initialize()
|
await requester_inst.initialize()
|
||||||
|
|
||||||
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ class RuntimeProvider:
|
|||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
msg, usage_info = result
|
msg, usage_info = result
|
||||||
if usage_info:
|
if usage_info:
|
||||||
input_tokens = usage_info.get('input_tokens', 0)
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
output_tokens = usage_info.get('output_tokens', 0)
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
return msg
|
return msg
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
@@ -128,7 +128,6 @@ class RuntimeProvider:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
status = 'success'
|
status = 'success'
|
||||||
error_message = None
|
error_message = None
|
||||||
# Note: Stream doesn't easily provide token counts, set to 0
|
|
||||||
input_tokens = 0
|
input_tokens = 0
|
||||||
output_tokens = 0
|
output_tokens = 0
|
||||||
|
|
||||||
@@ -143,6 +142,15 @@ class RuntimeProvider:
|
|||||||
remove_think=remove_think,
|
remove_think=remove_think,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
# Extract usage from stream if available (stored by LiteLLM requester)
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
if '_stream_usage' in query.variables:
|
||||||
|
usage_info = query.variables['_stream_usage']
|
||||||
|
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||||
|
output_tokens = usage_info.get('completion_tokens', 0)
|
||||||
|
del query.variables['_stream_usage']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status = 'error'
|
status = 'error'
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
|
|||||||
397
src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
Normal file
397
src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""LiteLLM unified requester for chat, embedding, and rerank."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import acompletion, aembedding, arerank
|
||||||
|
|
||||||
|
from .. import errors, requester
|
||||||
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||||
|
"""LiteLLM unified API requester supporting chat, embedding, and rerank."""
|
||||||
|
|
||||||
|
default_config: dict[str, typing.Any] = {
|
||||||
|
'base_url': '',
|
||||||
|
'timeout': 120,
|
||||||
|
'custom_llm_provider': '',
|
||||||
|
'drop_params': False,
|
||||||
|
'num_retries': 0,
|
||||||
|
'api_version': '',
|
||||||
|
}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize LiteLLM client settings."""
|
||||||
|
# LiteLLM doesn't require explicit client initialization
|
||||||
|
# Configuration is passed per-request via litellm params
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _build_litellm_model_name(self, model_name: str, custom_llm_provider: str | None = None) -> str:
|
||||||
|
"""Build LiteLLM model name with provider prefix if needed."""
|
||||||
|
provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '')
|
||||||
|
if provider:
|
||||||
|
# LiteLLM format: provider/model_name
|
||||||
|
return f'{provider}/{model_name}'
|
||||||
|
# If no custom provider, assume model_name already includes prefix or is OpenAI-compatible
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]:
|
||||||
|
"""Convert LangBot messages to LiteLLM/OpenAI format."""
|
||||||
|
req_messages = []
|
||||||
|
for m in messages:
|
||||||
|
msg_dict = m.dict(exclude_none=True)
|
||||||
|
content = msg_dict.get('content')
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get('type') == 'image_base64':
|
||||||
|
part['image_url'] = {'url': part['image_base64']}
|
||||||
|
part['type'] = 'image_url'
|
||||||
|
del part['image_base64']
|
||||||
|
|
||||||
|
req_messages.append(msg_dict)
|
||||||
|
|
||||||
|
return req_messages
|
||||||
|
|
||||||
|
def _process_thinking_content(self, content: str, reasoning_content: str | None, remove_think: bool) -> str:
|
||||||
|
"""Process thinking/reasoning content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The main content from response
|
||||||
|
reasoning_content: Separate reasoning content from model
|
||||||
|
remove_think: If True, remove thinking markers; if False, preserve them
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed content string
|
||||||
|
"""
|
||||||
|
# Extract and handle thinking tags
|
||||||
|
if content and 'CRETIRE_REASONING_BEGINk' in content and 'CRETIRE_REASONING_ENDk' in content:
|
||||||
|
import re
|
||||||
|
|
||||||
|
think_pattern = r'CRETIRE_REASONING_BEGINk(.*?)CRETIRE_REASONING_ENDk'
|
||||||
|
|
||||||
|
if remove_think:
|
||||||
|
# Remove thinking tags and their content from output
|
||||||
|
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||||
|
# else: preserve thinking content as-is
|
||||||
|
|
||||||
|
# Handle separate reasoning_content field
|
||||||
|
# Currently we don't include reasoning_content in user-facing output regardless of remove_think
|
||||||
|
# because it's typically internal model reasoning, not user-visible thinking
|
||||||
|
return content or ''
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> dict:
|
||||||
|
"""Extract usage info from LiteLLM response."""
|
||||||
|
usage = response.usage
|
||||||
|
return {
|
||||||
|
'prompt_tokens': usage.prompt_tokens or 0,
|
||||||
|
'completion_tokens': usage.completion_tokens or 0,
|
||||||
|
'total_tokens': usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict:
|
||||||
|
"""Apply common requester config to args dict."""
|
||||||
|
if self.requester_cfg.get('base_url'):
|
||||||
|
args['api_base'] = self.requester_cfg['base_url']
|
||||||
|
if self.requester_cfg.get('timeout'):
|
||||||
|
args['timeout'] = self.requester_cfg['timeout']
|
||||||
|
if include_retry_params:
|
||||||
|
if self.requester_cfg.get('drop_params'):
|
||||||
|
args['drop_params'] = self.requester_cfg['drop_params']
|
||||||
|
if self.requester_cfg.get('num_retries'):
|
||||||
|
args['num_retries'] = self.requester_cfg['num_retries']
|
||||||
|
if self.requester_cfg.get('api_version'):
|
||||||
|
args['api_version'] = self.requester_cfg['api_version']
|
||||||
|
return args
|
||||||
|
|
||||||
|
def _handle_litellm_error(self, e: Exception) -> None:
|
||||||
|
"""Convert LiteLLM exceptions to RequesterError. Never returns, always raises."""
|
||||||
|
# Check more specific exceptions first (they inherit from base exceptions)
|
||||||
|
if isinstance(e, litellm.ContextWindowExceededError):
|
||||||
|
raise errors.RequesterError(f'上下文长度超限: {str(e)}')
|
||||||
|
if isinstance(e, litellm.BadRequestError):
|
||||||
|
raise errors.RequesterError(f'请求参数错误: {str(e)}')
|
||||||
|
if isinstance(e, litellm.AuthenticationError):
|
||||||
|
raise errors.RequesterError(f'API key 无效: {str(e)}')
|
||||||
|
if isinstance(e, litellm.NotFoundError):
|
||||||
|
raise errors.RequesterError(f'模型或路径无效: {str(e)}')
|
||||||
|
if isinstance(e, litellm.RateLimitError):
|
||||||
|
raise errors.RequesterError(f'请求过于频繁或余额不足: {str(e)}')
|
||||||
|
if isinstance(e, litellm.Timeout):
|
||||||
|
raise errors.RequesterError(f'请求超时: {str(e)}')
|
||||||
|
if isinstance(e, litellm.APIConnectionError):
|
||||||
|
raise errors.RequesterError(f'连接错误: {str(e)}')
|
||||||
|
if isinstance(e, litellm.APIError):
|
||||||
|
raise errors.RequesterError(f'API 错误: {str(e)}')
|
||||||
|
raise errors.RequesterError(f'未知错误: {str(e)}')
|
||||||
|
|
||||||
|
async def _build_completion_args(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
stream: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Build common completion arguments for invoke_llm and invoke_llm_stream."""
|
||||||
|
req_messages = self._convert_messages(messages)
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'messages': req_messages,
|
||||||
|
'api_key': api_key,
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
args['stream'] = True
|
||||||
|
args['stream_options'] = {'include_usage': True}
|
||||||
|
self._build_common_args(args)
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
if funcs:
|
||||||
|
tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||||
|
if tools:
|
||||||
|
args['tools'] = tools
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
async def invoke_llm(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
remove_think: bool = False,
|
||||||
|
) -> tuple[provider_message.Message, dict]:
|
||||||
|
"""Invoke LLM and return message with usage info."""
|
||||||
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await acompletion(**args)
|
||||||
|
|
||||||
|
message_data = response.choices[0].message.model_dump()
|
||||||
|
if 'role' not in message_data or message_data['role'] is None:
|
||||||
|
message_data['role'] = 'assistant'
|
||||||
|
|
||||||
|
content = message_data.get('content', '')
|
||||||
|
reasoning_content = message_data.get('reasoning_content', None)
|
||||||
|
message_data['content'] = self._process_thinking_content(content, reasoning_content, remove_think)
|
||||||
|
|
||||||
|
if 'reasoning_content' in message_data:
|
||||||
|
del message_data['reasoning_content']
|
||||||
|
|
||||||
|
message = provider_message.Message(**message_data)
|
||||||
|
usage_info = self._extract_usage(response)
|
||||||
|
|
||||||
|
return message, usage_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_llm_stream(
|
||||||
|
self,
|
||||||
|
query: pipeline_query.Query,
|
||||||
|
model: requester.RuntimeLLMModel,
|
||||||
|
messages: typing.List[provider_message.Message],
|
||||||
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
remove_think: bool = False,
|
||||||
|
) -> provider_message.MessageChunk:
|
||||||
|
"""Invoke LLM streaming and yield chunks."""
|
||||||
|
args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True)
|
||||||
|
|
||||||
|
chunk_idx = 0
|
||||||
|
role = 'assistant'
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await acompletion(**args)
|
||||||
|
async for chunk in response:
|
||||||
|
# Check for usage chunk (final chunk with stream_options include_usage)
|
||||||
|
if hasattr(chunk, 'usage') and chunk.usage and (not hasattr(chunk, 'choices') or not chunk.choices):
|
||||||
|
usage_info = {
|
||||||
|
'prompt_tokens': chunk.usage.prompt_tokens or 0,
|
||||||
|
'completion_tokens': chunk.usage.completion_tokens or 0,
|
||||||
|
'total_tokens': chunk.usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
if query.variables is None:
|
||||||
|
query.variables = {}
|
||||||
|
query.variables['_stream_usage'] = usage_info
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||||
|
finish_reason = getattr(choice, 'finish_reason', None)
|
||||||
|
|
||||||
|
if 'role' in delta and delta['role']:
|
||||||
|
role = delta['role']
|
||||||
|
|
||||||
|
delta_content = delta.get('content', '')
|
||||||
|
reasoning_content = delta.get('reasoning_content', '')
|
||||||
|
|
||||||
|
if reasoning_content:
|
||||||
|
chunk_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
||||||
|
chunk_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_data = {
|
||||||
|
'role': role,
|
||||||
|
'content': delta_content if delta_content else None,
|
||||||
|
'tool_calls': delta.get('tool_calls'),
|
||||||
|
'is_final': bool(finish_reason),
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||||
|
yield provider_message.MessageChunk(**chunk_data)
|
||||||
|
chunk_idx += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_embedding(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeEmbeddingModel,
|
||||||
|
input_text: list[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> tuple[list[list[float]], dict]:
|
||||||
|
"""Invoke embedding and return vectors with usage info."""
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'input': input_text,
|
||||||
|
'api_key': api_key,
|
||||||
|
}
|
||||||
|
self._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
args.update(model.model_entity.extra_args)
|
||||||
|
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await aembedding(**args)
|
||||||
|
|
||||||
|
embeddings = [d.embedding for d in response.data]
|
||||||
|
usage_info = self._extract_usage(response)
|
||||||
|
|
||||||
|
return embeddings, usage_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""Invoke rerank and return relevance scores."""
|
||||||
|
model_name = self._build_litellm_model_name(model.model_entity.name)
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'model': model_name,
|
||||||
|
'query': query,
|
||||||
|
'documents': documents,
|
||||||
|
'api_key': api_key,
|
||||||
|
'top_n': min(len(documents), 64),
|
||||||
|
}
|
||||||
|
self._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
args.update(model.model_entity.extra_args)
|
||||||
|
|
||||||
|
args.update(extra_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await arerank(**args)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for r in response.results:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
'index': r.get('index', 0),
|
||||||
|
'relevance_score': r.get('relevance_score', 0.0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
scores = [r['relevance_score'] for r in results]
|
||||||
|
min_score = min(scores)
|
||||||
|
max_score = max(scores)
|
||||||
|
if max_score - min_score > 1e-6:
|
||||||
|
for r in results:
|
||||||
|
r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_litellm_error(e)
|
||||||
|
|
||||||
|
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
|
||||||
|
"""Scan models supported by the provider."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
base_url = self.requester_cfg.get('base_url', '').rstrip('/')
|
||||||
|
timeout = self.requester_cfg.get('timeout', 120)
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise errors.RequesterError('Base URL required for model scanning')
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers['Authorization'] = f'Bearer {api_key}'
|
||||||
|
|
||||||
|
models_url = f'{base_url}/models'
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
payload = response.json()
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for item in payload.get('data', []):
|
||||||
|
model_id = item.get('id')
|
||||||
|
if not model_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Infer model type
|
||||||
|
normalized_id = (model_id or '').lower()
|
||||||
|
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
|
||||||
|
model_type = 'embedding' if any(kw in normalized_id for kw in embedding_keywords) else 'llm'
|
||||||
|
|
||||||
|
models.append(
|
||||||
|
{
|
||||||
|
'id': model_id,
|
||||||
|
'name': model_id,
|
||||||
|
'type': model_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))
|
||||||
|
|
||||||
|
return {'models': models}
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise errors.RequesterError(f'Model scan failed: {e.response.status_code}')
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise errors.RequesterError('Model scan timeout')
|
||||||
|
except Exception as e:
|
||||||
|
raise errors.RequesterError(f'Model scan error: {str(e)}')
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: litellm-chat
|
||||||
|
label:
|
||||||
|
en_US: LiteLLM (Unified)
|
||||||
|
zh_Hans: LiteLLM (统一请求器)
|
||||||
|
icon: litellm.svg
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
- name: custom_llm_provider
|
||||||
|
label:
|
||||||
|
en_US: Custom Provider
|
||||||
|
zh_Hans: 自定义 Provider
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
description:
|
||||||
|
en_US: Force provider type (e.g., anthropic, openai, gemini)
|
||||||
|
zh_Hans: 强制指定 provider 类型(如 anthropic, openai, gemini)
|
||||||
|
- name: drop_params
|
||||||
|
label:
|
||||||
|
en_US: Drop Unsupported Params
|
||||||
|
zh_Hans: 丢弃不支持参数
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
- name: num_retries
|
||||||
|
label:
|
||||||
|
en_US: Number of Retries
|
||||||
|
zh_Hans: 重试次数
|
||||||
|
type: integer
|
||||||
|
required: false
|
||||||
|
default: 0
|
||||||
|
- name: api_version
|
||||||
|
label:
|
||||||
|
en_US: API Version
|
||||||
|
zh_Hans: API 版本
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
support_type:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
provider_category: unified
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./litellmchat.py
|
||||||
|
attr: LiteLLMRequester
|
||||||
@@ -57,41 +57,6 @@ class ToolManager:
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
|
|
||||||
"""为anthropic生成函数列表
|
|
||||||
|
|
||||||
e.g.
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "get_stock_price",
|
|
||||||
"description": "Get the current stock price for a given ticker symbol.",
|
|
||||||
"input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"ticker": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["ticker"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
for function in use_funcs:
|
|
||||||
function_schema = {
|
|
||||||
'name': function.name,
|
|
||||||
'description': function.description,
|
|
||||||
'input_schema': function.parameters,
|
|
||||||
}
|
|
||||||
tools.append(function_schema)
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
||||||
"""执行函数调用"""
|
"""执行函数调用"""
|
||||||
|
|
||||||
|
|||||||
1
tests/unit_tests/provider/__init__.py
Normal file
1
tests/unit_tests/provider/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Provider requester tests"""
|
||||||
633
tests/unit_tests/provider/test_litellmchat.py
Normal file
633
tests/unit_tests/provider/test_litellmchat.py
Normal file
@@ -0,0 +1,633 @@
|
|||||||
|
"""
|
||||||
|
Tests for LiteLLMRequester - unified requester for chat, embedding, and rerank.
|
||||||
|
|
||||||
|
These tests verify:
|
||||||
|
- Parameter building and LiteLLM API calls
|
||||||
|
- Response processing and usage extraction
|
||||||
|
- Error handling and exception translation
|
||||||
|
- Model name building with provider prefix
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langbot.pkg.provider.modelmgr.requesters import litellmchat
|
||||||
|
from langbot.pkg.provider.modelmgr import errors
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeModel:
|
||||||
|
"""Mock RuntimeLLMModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'gpt-4o', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeEmbeddingModel:
|
||||||
|
"""Mock RuntimeEmbeddingModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'text-embedding-3-small', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class MockRuntimeRerankModel:
|
||||||
|
"""Mock RuntimeRerankModel for testing"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = 'cohere/rerank-english-v3.0', api_key: str = 'test-key'):
|
||||||
|
self.model_entity = Mock()
|
||||||
|
self.model_entity.name = model_name
|
||||||
|
self.model_entity.extra_args = {}
|
||||||
|
self.provider = Mock()
|
||||||
|
self.provider.token_mgr = Mock()
|
||||||
|
self.provider.token_mgr.get_token = Mock(return_value=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildLiteLLMModelName:
|
||||||
|
"""Test _build_litellm_model_name method"""
|
||||||
|
|
||||||
|
def test_no_provider_prefix(self):
|
||||||
|
"""Test model name without provider prefix"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': ''})
|
||||||
|
result = requester._build_litellm_model_name('gpt-4o')
|
||||||
|
assert result == 'gpt-4o'
|
||||||
|
|
||||||
|
def test_with_provider_prefix(self):
|
||||||
|
"""Test model name with provider prefix"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||||
|
result = requester._build_litellm_model_name('gpt-4o')
|
||||||
|
assert result == 'openai/gpt-4o'
|
||||||
|
|
||||||
|
def test_override_provider(self):
|
||||||
|
"""Test override provider via parameter"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'})
|
||||||
|
result = requester._build_litellm_model_name('claude-3', custom_llm_provider='anthropic')
|
||||||
|
assert result == 'anthropic/claude-3'
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUsage:
|
||||||
|
"""Test _extract_usage method"""
|
||||||
|
|
||||||
|
def test_extract_usage_with_data(self):
|
||||||
|
"""Test extraction with valid usage data"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
response = Mock()
|
||||||
|
response.usage = Mock()
|
||||||
|
response.usage.prompt_tokens = 100
|
||||||
|
response.usage.completion_tokens = 50
|
||||||
|
response.usage.total_tokens = 150
|
||||||
|
|
||||||
|
result = requester._extract_usage(response)
|
||||||
|
|
||||||
|
assert result['prompt_tokens'] == 100
|
||||||
|
assert result['completion_tokens'] == 50
|
||||||
|
assert result['total_tokens'] == 150
|
||||||
|
|
||||||
|
def test_extract_usage_with_zero_values(self):
|
||||||
|
"""Test extraction when values are 0"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
response = Mock()
|
||||||
|
response.usage = Mock()
|
||||||
|
response.usage.prompt_tokens = 0
|
||||||
|
response.usage.completion_tokens = 0
|
||||||
|
response.usage.total_tokens = 0
|
||||||
|
|
||||||
|
result = requester._extract_usage(response)
|
||||||
|
|
||||||
|
assert result['prompt_tokens'] == 0
|
||||||
|
assert result['completion_tokens'] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessThinkingContent:
|
||||||
|
"""Test _process_thinking_content method"""
|
||||||
|
|
||||||
|
def test_no_thinking_markers(self):
|
||||||
|
"""Test content without thinking markers"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
result = requester._process_thinking_content('Hello world', None, remove_think=True)
|
||||||
|
assert result == 'Hello world'
|
||||||
|
|
||||||
|
def test_remove_thinking_markers(self):
|
||||||
|
"""Test removing thinking markers when remove_think=True"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
|
||||||
|
result = requester._process_thinking_content(content, None, remove_think=True)
|
||||||
|
assert result == 'The answer is 42.'
|
||||||
|
|
||||||
|
def test_preserve_thinking_markers(self):
|
||||||
|
"""Test preserving thinking markers when remove_think=False"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.'
|
||||||
|
result = requester._process_thinking_content(content, None, remove_think=False)
|
||||||
|
assert 'CRETIRE_REASONING_BEGINk' in result
|
||||||
|
assert 'The answer is 42.' in result
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
"""Test empty content"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
result = requester._process_thinking_content('', None, remove_think=True)
|
||||||
|
assert result == ''
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildCommonArgs:
|
||||||
|
"""Test _build_common_args method"""
|
||||||
|
|
||||||
|
def test_build_args_with_all_params(self):
|
||||||
|
"""Test building args with all config params"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
'drop_params': True,
|
||||||
|
'num_retries': 3,
|
||||||
|
'api_version': '2024-01-01',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
requester._build_common_args(args)
|
||||||
|
|
||||||
|
assert args['api_base'] == 'https://api.openai.com/v1'
|
||||||
|
assert args['timeout'] == 60
|
||||||
|
assert args['drop_params'] == True
|
||||||
|
assert args['num_retries'] == 3
|
||||||
|
assert args['api_version'] == '2024-01-01'
|
||||||
|
|
||||||
|
def test_build_args_without_retry_params(self):
|
||||||
|
"""Test building args without retry params for embedding/rerank"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
'num_retries': 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
requester._build_common_args(args, include_retry_params=False)
|
||||||
|
|
||||||
|
assert args['api_base'] == 'https://api.openai.com/v1'
|
||||||
|
assert args['timeout'] == 60
|
||||||
|
assert 'num_retries' not in args
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleLiteLLMError:
|
||||||
|
"""Test _handle_litellm_error method"""
|
||||||
|
|
||||||
|
def test_bad_request_error(self):
|
||||||
|
"""Test BadRequestError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
# Create proper LiteLLM exception with required args
|
||||||
|
error = litellm.BadRequestError(message='test error', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求参数错误' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_authentication_error(self):
|
||||||
|
"""Test AuthenticationError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert 'API key 无效' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_rate_limit_error(self):
|
||||||
|
"""Test RateLimitError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.RateLimitError(message='rate limited', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求过于频繁' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_timeout_error(self):
|
||||||
|
"""Test Timeout translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.Timeout(message='timeout', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '请求超时' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_context_window_error(self):
|
||||||
|
"""Test ContextWindowExceededError translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
error = litellm.ContextWindowExceededError(message='context too long', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(error)
|
||||||
|
|
||||||
|
assert '上下文长度超限' in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_unknown_error(self):
|
||||||
|
"""Test unknown error translation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
requester._handle_litellm_error(Exception('unknown'))
|
||||||
|
|
||||||
|
assert '未知错误' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeLLM:
|
||||||
|
"""Test invoke_llm method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_basic(self):
|
||||||
|
"""Test basic LLM invocation"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=mock_ap,
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock()]
|
||||||
|
mock_response.choices[0].message = Mock()
|
||||||
|
mock_response.choices[0].message.model_dump = Mock(
|
||||||
|
return_value={
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': 'Hello! How can I help you?',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 10
|
||||||
|
mock_response.usage.completion_tokens = 20
|
||||||
|
mock_response.usage.total_tokens = 30
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
|
||||||
|
# Patch acompletion at the import location
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
result_msg, usage = await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_msg.role == 'assistant'
|
||||||
|
assert result_msg.content == 'Hello! How can I help you?'
|
||||||
|
assert usage['prompt_tokens'] == 10
|
||||||
|
assert usage['completion_tokens'] == 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_with_tools(self):
|
||||||
|
"""Test LLM invocation with function calling"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(
|
||||||
|
return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||||
|
)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock()]
|
||||||
|
mock_response.choices[0].message = Mock()
|
||||||
|
mock_response.choices[0].message.model_dump = Mock(
|
||||||
|
return_value={
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': None,
|
||||||
|
'tool_calls': [
|
||||||
|
{'id': 'call_123', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 15
|
||||||
|
mock_response.usage.completion_tokens = 10
|
||||||
|
mock_response.usage.total_tokens = 25
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='What is the weather?')]
|
||||||
|
# Create proper LLMTool with all required fields
|
||||||
|
funcs = [Mock(spec=resource_tool.LLMTool)]
|
||||||
|
funcs[0].name = 'get_weather'
|
||||||
|
funcs[0].description = 'Get weather'
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
result_msg, usage = await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
funcs=funcs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_msg.tool_calls is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_llm_error_handling(self):
|
||||||
|
"""Test LLM invocation error handling"""
|
||||||
|
mock_ap = Mock()
|
||||||
|
mock_ap.tool_mgr = Mock()
|
||||||
|
mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={})
|
||||||
|
|
||||||
|
model = MockRuntimeModel('gpt-4o', 'test-api-key')
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
|
||||||
|
error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai')
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, side_effect=error):
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
await requester.invoke_llm(
|
||||||
|
query=None,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 'API key 无效' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeEmbedding:
|
||||||
|
"""Test invoke_embedding method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_embedding_basic(self):
|
||||||
|
"""Test basic embedding invocation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeEmbeddingModel('text-embedding-3-small', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM embedding response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.data = [
|
||||||
|
Mock(embedding=[0.1, 0.2, 0.3]),
|
||||||
|
Mock(embedding=[0.4, 0.5, 0.6]),
|
||||||
|
]
|
||||||
|
mock_response.usage = Mock()
|
||||||
|
mock_response.usage.prompt_tokens = 20
|
||||||
|
mock_response.usage.completion_tokens = 0
|
||||||
|
mock_response.usage.total_tokens = 20
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'aembedding', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
embeddings, usage = await requester.invoke_embedding(
|
||||||
|
model=model,
|
||||||
|
input_text=['Hello', 'World'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(embeddings) == 2
|
||||||
|
assert embeddings[0] == [0.1, 0.2, 0.3]
|
||||||
|
assert embeddings[1] == [0.4, 0.5, 0.6]
|
||||||
|
assert usage['prompt_tokens'] == 20
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvokeRerank:
|
||||||
|
"""Test invoke_rerank method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_basic(self):
|
||||||
|
"""Test basic rerank invocation"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.cohere.ai',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock LiteLLM rerank response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.95},
|
||||||
|
{'index': 1, 'relevance_score': 0.3},
|
||||||
|
{'index': 2, 'relevance_score': 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='What is the capital of France?',
|
||||||
|
documents=['Paris is the capital.', 'London is a city.', 'France is in Europe.'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
# Scores should be normalized
|
||||||
|
assert results[0]['index'] == 0
|
||||||
|
assert results[0]['relevance_score'] >= 0 and results[0]['relevance_score'] <= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_normalization(self):
|
||||||
|
"""Test rerank score normalization"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
# Mock response with varying scores
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.9},
|
||||||
|
{'index': 1, 'relevance_score': 0.1},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='test query',
|
||||||
|
documents=['doc1', 'doc2'],
|
||||||
|
)
|
||||||
|
|
||||||
|
# After normalization: 0.9 -> 1.0, 0.1 -> 0.0
|
||||||
|
assert results[0]['relevance_score'] == 1.0
|
||||||
|
assert results[1]['relevance_score'] == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invoke_rerank_single_document(self):
|
||||||
|
"""Test rerank with single document (no normalization needed)"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key')
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.results = [
|
||||||
|
{'index': 0, 'relevance_score': 0.5},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response):
|
||||||
|
results = await requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query='test query',
|
||||||
|
documents=['doc1'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
# Single score stays as is (min==max, no normalization)
|
||||||
|
assert results[0]['relevance_score'] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertMessages:
|
||||||
|
"""Test _convert_messages method"""
|
||||||
|
|
||||||
|
def test_convert_simple_message(self):
|
||||||
|
"""Test converting simple text message"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [provider_message.Message(role='user', content='Hello')]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]['role'] == 'user'
|
||||||
|
assert result[0]['content'] == 'Hello'
|
||||||
|
|
||||||
|
def test_convert_message_with_image_base64(self):
|
||||||
|
"""Test converting message with image_base64 content"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
provider_message.Message(
|
||||||
|
role='user',
|
||||||
|
content=[
|
||||||
|
{'type': 'text', 'text': 'What is in this image?'},
|
||||||
|
{'type': 'image_base64', 'image_base64': 'data:image/png;base64,abc123'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
content = result[0]['content']
|
||||||
|
assert isinstance(content, list)
|
||||||
|
# Check image_base64 converted to image_url
|
||||||
|
image_part = [p for p in content if p.get('type') == 'image_url'][0]
|
||||||
|
assert 'image_url' in image_part
|
||||||
|
assert image_part['image_url']['url'] == 'data:image/png;base64,abc123'
|
||||||
|
|
||||||
|
def test_convert_message_with_multiple_text_parts(self):
|
||||||
|
"""Test converting message with multiple text parts (LiteLLM handles this)"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
provider_message.Message(
|
||||||
|
role='user',
|
||||||
|
content=[
|
||||||
|
{'type': 'text', 'text': 'Hello'},
|
||||||
|
{'type': 'text', 'text': 'World'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = requester._convert_messages(messages)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
# LiteLLM handles multiple text parts, we pass them through
|
||||||
|
assert isinstance(result[0]['content'], list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanModels:
|
||||||
|
"""Test scan_models method"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_models_basic(self):
|
||||||
|
"""Test basic model scanning"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
'timeout': 60,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock httpx response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json = Mock(
|
||||||
|
return_value={
|
||||||
|
'data': [
|
||||||
|
{'id': 'gpt-4o'},
|
||||||
|
{'id': 'text-embedding-3-small'},
|
||||||
|
{'id': 'gpt-3.5-turbo'},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client:
|
||||||
|
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
|
||||||
|
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
result = await requester.scan_models(api_key='test-key')
|
||||||
|
|
||||||
|
assert 'models' in result
|
||||||
|
assert len(result['models']) == 3
|
||||||
|
# Check LLM models are first
|
||||||
|
assert result['models'][0]['type'] == 'llm'
|
||||||
|
# Check embedding model is detected
|
||||||
|
embedding_models = [m for m in result['models'] if m['type'] == 'embedding']
|
||||||
|
assert len(embedding_models) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scan_models_no_base_url(self):
|
||||||
|
"""Test scan_models without base_url raises error"""
|
||||||
|
requester = litellmchat.LiteLLMRequester(
|
||||||
|
ap=Mock(),
|
||||||
|
config={
|
||||||
|
'base_url': '',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(errors.RequesterError) as exc_info:
|
||||||
|
await requester.scan_models()
|
||||||
|
|
||||||
|
assert 'Base URL required' in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user