diff --git a/pyproject.toml b/pyproject.toml index 7e53cb66..0ec0d418 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ dependencies = [ "pymilvus>=2.6.4", "pgvector>=0.4.1", "botocore>=1.42.39", + "litellm>=1.0.0", ] keywords = [ "bot", diff --git a/src/langbot/pkg/api/http/controller/groups/monitoring.py b/src/langbot/pkg/api/http/controller/groups/monitoring.py index 11c9e272..65640b6e 100644 --- a/src/langbot/pkg/api/http/controller/groups/monitoring.py +++ b/src/langbot/pkg/api/http/controller/groups/monitoring.py @@ -46,6 +46,30 @@ class MonitoringRouterGroup(group.RouterGroup): return self.success(data=metrics) + @self.route('/token-statistics', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) + async def get_token_statistics() -> str: + """Get detailed token usage statistics (summary, per-model, timeseries).""" + bot_ids = quart.request.args.getlist('botId') + pipeline_ids = quart.request.args.getlist('pipelineId') + start_time_str = quart.request.args.get('startTime') + end_time_str = quart.request.args.get('endTime') + bucket = quart.request.args.get('bucket', 'hour') + if bucket not in ('hour', 'day'): + bucket = 'hour' + + start_time = parse_iso_datetime(start_time_str) + end_time = parse_iso_datetime(end_time_str) + + stats = await self.ap.monitoring_service.get_token_statistics( + bot_ids=bot_ids if bot_ids else None, + pipeline_ids=pipeline_ids if pipeline_ids else None, + start_time=start_time, + end_time=end_time, + bucket=bucket, + ) + + return self.success(data=stats) + @self.route('/messages', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def get_messages() -> str: """Get message logs""" diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index 320104d8..87298c08 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -34,6 +34,46 @@ def _runtime_model_data(model_uuid: str, model_data: dict) -> dict: return {**model_data, 'uuid': model_uuid} +async def _validate_provider_supports(ap: app.Application, provider_uuid: str, model_type: str) -> None: + """Validate that the provider's requester declares support for ``model_type``. + + ``model_type`` is one of the manifest ``support_type`` values: + 'llm', 'text-embedding', 'rerank'. Raises ValueError when the requester + manifest does not list the requested type. This is a server-side guard so + a model cannot be attached to a provider that does not support it, even if + the frontend tab restriction is bypassed. + """ + model_mgr = getattr(ap, 'model_mgr', None) + if model_mgr is None: + return + + provider_dict = getattr(model_mgr, 'provider_dict', None) + if not provider_dict: + return + runtime_provider = provider_dict.get(provider_uuid) + if runtime_provider is None: + return + + requester_name = getattr(getattr(runtime_provider, 'provider_entity', None), 'requester', None) + if not requester_name: + return + + get_manifest = getattr(model_mgr, 'get_available_requester_manifest_by_name', None) + if not callable(get_manifest): + return + manifest = get_manifest(requester_name) + if manifest is None: + return + + spec = getattr(manifest, 'spec', None) or {} + support_type = spec.get('support_type') if isinstance(spec, dict) else None + # When a manifest omits support_type, do not block (backward compatible). + if not support_type: + return + if model_type not in support_type: + raise ValueError(f'Provider requester "{requester_name}" does not support {model_type} models') + + class LLMModelsService: ap: app.Application @@ -96,6 +136,8 @@ class LLMModelsService: ) model_data['provider_uuid'] = provider_uuid + await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'llm') + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)) runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) @@ -274,6 +316,8 @@ class EmbeddingModelsService: ) model_data['provider_uuid'] = provider_uuid + await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'text-embedding') + await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) ) @@ -434,6 +478,8 @@ class RerankModelsService: ) model_data['provider_uuid'] = provider_uuid + await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'rerank') + await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_model.RerankModel).values(**model_data) ) diff --git a/src/langbot/pkg/api/http/service/monitoring.py b/src/langbot/pkg/api/http/service/monitoring.py index 1ba66482..3e8e0cde 100644 --- a/src/langbot/pkg/api/http/service/monitoring.py +++ b/src/langbot/pkg/api/http/service/monitoring.py @@ -472,6 +472,179 @@ class MonitoringService: 'active_sessions': active_sessions, } + async def get_token_statistics( + self, + bot_ids: list[str] | None = None, + pipeline_ids: list[str] | None = None, + start_time: datetime.datetime | None = None, + end_time: datetime.datetime | None = None, + bucket: str = 'hour', + ) -> dict: + """Get detailed token usage statistics for production observability. + + Returns: + - summary: aggregate token counters and call/latency stats over the window + - by_model: per-model token + call breakdown (sorted by total tokens desc) + - timeseries: token usage bucketed by `bucket` ('hour' or 'day') + + Only successful LLM calls are counted toward token totals; error calls are + reported separately so a spike in failures is visible without polluting + token accounting. + """ + LLMCall = persistence_monitoring.MonitoringLLMCall + + conditions = [] + if bot_ids: + conditions.append(LLMCall.bot_id.in_(bot_ids)) + if pipeline_ids: + conditions.append(LLMCall.pipeline_id.in_(pipeline_ids)) + if start_time: + conditions.append(LLMCall.timestamp >= start_time) + if end_time: + conditions.append(LLMCall.timestamp <= end_time) + + def _apply(query): + if conditions: + query = query.where(sqlalchemy.and_(*conditions)) + return query + + # ---- Summary aggregates ---- + summary_query = _apply( + sqlalchemy.select( + sqlalchemy.func.count(LLMCall.id), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.input_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.output_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.total_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.duration), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.cost), 0.0), + sqlalchemy.func.sum(sqlalchemy.case((LLMCall.status == 'success', 1), else_=0)), + sqlalchemy.func.sum(sqlalchemy.case((LLMCall.status == 'error', 1), else_=0)), + # Count of successful calls that nonetheless recorded zero tokens — + # a data-quality signal that usage reporting may be broken upstream. + sqlalchemy.func.sum( + sqlalchemy.case( + (sqlalchemy.and_(LLMCall.status == 'success', LLMCall.total_tokens == 0), 1), + else_=0, + ) + ), + ) + ) + summary_result = await self.ap.persistence_mgr.execute_async(summary_query) + row = summary_result.first() + ( + total_calls, + total_input_tokens, + total_output_tokens, + total_tokens, + total_duration, + total_cost, + success_calls, + error_calls, + zero_token_success_calls, + ) = row if row else (0, 0, 0, 0, 0, 0.0, 0, 0, 0) + + total_calls = total_calls or 0 + success_calls = success_calls or 0 + error_calls = error_calls or 0 + zero_token_success_calls = zero_token_success_calls or 0 + + summary = { + 'total_calls': total_calls, + 'success_calls': success_calls, + 'error_calls': error_calls, + 'total_input_tokens': int(total_input_tokens or 0), + 'total_output_tokens': int(total_output_tokens or 0), + 'total_tokens': int(total_tokens or 0), + 'total_cost': round(float(total_cost or 0.0), 6), + 'avg_tokens_per_call': int((total_tokens or 0) / total_calls) if total_calls > 0 else 0, + 'avg_duration_ms': int((total_duration or 0) / total_calls) if total_calls > 0 else 0, + 'avg_tokens_per_second': round((total_output_tokens or 0) / (total_duration / 1000), 2) + if total_duration and total_duration > 0 + else 0, + 'zero_token_success_calls': zero_token_success_calls, + } + + # ---- Per-model breakdown ---- + by_model_query = _apply( + sqlalchemy.select( + LLMCall.model_name, + sqlalchemy.func.count(LLMCall.id), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.input_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.output_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.total_tokens), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.duration), 0), + sqlalchemy.func.coalesce(sqlalchemy.func.sum(LLMCall.cost), 0.0), + sqlalchemy.func.sum(sqlalchemy.case((LLMCall.status == 'error', 1), else_=0)), + ).group_by(LLMCall.model_name) + ) + by_model_result = await self.ap.persistence_mgr.execute_async(by_model_query) + by_model = [] + for mrow in by_model_result.all(): + ( + model_name, + m_calls, + m_in, + m_out, + m_total, + m_duration, + m_cost, + m_errors, + ) = mrow + m_calls = m_calls or 0 + by_model.append( + { + 'model_name': model_name, + 'calls': m_calls, + 'error_calls': m_errors or 0, + 'input_tokens': int(m_in or 0), + 'output_tokens': int(m_out or 0), + 'total_tokens': int(m_total or 0), + 'cost': round(float(m_cost or 0.0), 6), + 'avg_tokens_per_call': int((m_total or 0) / m_calls) if m_calls > 0 else 0, + 'avg_duration_ms': int((m_duration or 0) / m_calls) if m_calls > 0 else 0, + } + ) + by_model.sort(key=lambda x: x['total_tokens'], reverse=True) + + # ---- Time-bucketed series ---- + # Use a DB-agnostic bucketing approach: fetch (timestamp, tokens) rows and + # aggregate in Python. The window is bounded by the time filter, so this is + # cheap for typical dashboard ranges (hours/days). + series_query = _apply( + sqlalchemy.select( + LLMCall.timestamp, + LLMCall.input_tokens, + LLMCall.output_tokens, + LLMCall.total_tokens, + ).order_by(LLMCall.timestamp.asc()) + ) + series_result = await self.ap.persistence_mgr.execute_async(series_query) + + bucket_fmt = '%Y-%m-%d %H:00' if bucket == 'hour' else '%Y-%m-%d' + buckets: dict[str, dict] = {} + for srow in series_result.all(): + ts, s_in, s_out, s_total = srow + if ts is None: + continue + key = ts.strftime(bucket_fmt) + b = buckets.setdefault( + key, + {'bucket': key, 'input_tokens': 0, 'output_tokens': 0, 'total_tokens': 0, 'calls': 0}, + ) + b['input_tokens'] += int(s_in or 0) + b['output_tokens'] += int(s_out or 0) + b['total_tokens'] += int(s_total or 0) + b['calls'] += 1 + + timeseries = [buckets[k] for k in sorted(buckets.keys())] + + return { + 'summary': summary, + 'by_model': by_model, + 'timeseries': timeseries, + 'bucket': bucket, + } + async def get_messages( self, bot_ids: list[str] | None = None, diff --git a/src/langbot/pkg/core/bootutils/deps.py b/src/langbot/pkg/core/bootutils/deps.py index 1f653037..2cfd57e0 100644 --- a/src/langbot/pkg/core/bootutils/deps.py +++ b/src/langbot/pkg/core/bootutils/deps.py @@ -42,6 +42,7 @@ required_deps = { 'telegramify_markdown': 'telegramify-markdown', 'slack_sdk': 'slack_sdk', 'asyncpg': 'asyncpg', + 'litellm': 'litellm', } diff --git a/src/langbot/pkg/entity/persistence/model.py b/src/langbot/pkg/entity/persistence/model.py index 3c96acd7..5b5f1fe2 100644 --- a/src/langbot/pkg/entity/persistence/model.py +++ b/src/langbot/pkg/entity/persistence/model.py @@ -31,6 +31,7 @@ class LLMModel(Base): name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) + context_length = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) diff --git a/src/langbot/pkg/persistence/alembic/versions/0005_add_llm_context_length.py b/src/langbot/pkg/persistence/alembic/versions/0005_add_llm_context_length.py new file mode 100644 index 00000000..20a9d71e --- /dev/null +++ b/src/langbot/pkg/persistence/alembic/versions/0005_add_llm_context_length.py @@ -0,0 +1,39 @@ +"""add llm model context length + +Revision ID: 0005_add_llm_context_length +Revises: 0004_add_mcp_readme +Create Date: 2026-06-07 +""" + +import sqlalchemy as sa +from alembic import op + +revision = '0005_add_llm_context_length' +down_revision = '0004_add_mcp_readme' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add ``context_length`` to llm_models if the table exists and the column is + # missing. The table may have been created by create_all() with the column + # already present on fresh installs, so guard against duplicate-add; it may + # also be absent entirely (e.g. migrating a truly empty DB), so guard against + # a missing table too. + conn = op.get_bind() + inspector = sa.inspect(conn) + if 'llm_models' not in inspector.get_table_names(): + return + columns = {column['name'] for column in inspector.get_columns('llm_models')} + if 'context_length' not in columns: + op.add_column('llm_models', sa.Column('context_length', sa.Integer(), nullable=True)) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + if 'llm_models' not in inspector.get_table_names(): + return + columns = {column['name'] for column in inspector.get_columns('llm_models')} + if 'context_length' in columns: + op.drop_column('llm_models', 'context_length') diff --git a/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py b/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py new file mode 100644 index 00000000..81d7031e --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py @@ -0,0 +1,42 @@ +import sqlalchemy +from .. import migration + + +@migration.migration_class(26) +class DBMigrateLLMModelContextLength(migration.DBMigration): + """Add context_length column to LLM models""" + + async def upgrade(self): + columns = await self._get_columns('llm_models') + if 'context_length' not in columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN context_length INTEGER') + ) + + async def downgrade(self): + columns = await self._get_columns('llm_models') + if 'context_length' not in columns: + return + + if self.ap.persistence_mgr.db.name == 'postgresql': + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN IF EXISTS context_length') + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN context_length') + ) + + async def _get_columns(self, table_name: str) -> set[str]: + if self.ap.persistence_mgr.db.name == 'postgresql': + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + SELECT column_name FROM information_schema.columns + WHERE table_name = :table_name + """), + {'table_name': table_name}, + ) + return {row[0] for row in result.fetchall()} + + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name})')) + return {row[1] for row in result.fetchall()} diff --git a/src/langbot/pkg/pipeline/preproc/preproc.py b/src/langbot/pkg/pipeline/preproc/preproc.py index 8aa15750..84e9070c 100644 --- a/src/langbot/pkg/pipeline/preproc/preproc.py +++ b/src/langbot/pkg/pipeline/preproc/preproc.py @@ -109,7 +109,7 @@ class PreProcessor(stage.PipelineStage): if llm_model: query.use_llm_model_uuid = llm_model.model_entity.uuid - if llm_model.model_entity.abilities.__contains__('func_call'): + if 'func_call' in (llm_model.model_entity.abilities or []): # Get bound plugins and MCP servers for filtering tools bound_plugins = query.variables.get('_pipeline_bound_plugins', None) bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) @@ -159,11 +159,7 @@ class PreProcessor(stage.PipelineStage): # Check if this model supports vision, if not, remove all images # TODO this checking should be performed in runner, and in this stage, the image should be reserved - if ( - selected_runner == 'local-agent' - and llm_model - and not llm_model.model_entity.abilities.__contains__('vision') - ): + if selected_runner == 'local-agent' and llm_model and 'vision' not in (llm_model.model_entity.abilities or []): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -181,7 +177,7 @@ class PreProcessor(stage.PipelineStage): plain_text += me.text elif isinstance(me, platform_message.Image): if selected_runner != 'local-agent' or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') + llm_model and 'vision' in (llm_model.model_entity.abilities or []) ): if me.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(me.base64)) @@ -202,7 +198,7 @@ class PreProcessor(stage.PipelineStage): content_list.append(provider_message.ContentElement.from_text(msg.text)) elif isinstance(msg, platform_message.Image): if selected_runner != 'local-agent' or ( - llm_model and llm_model.model_entity.abilities.__contains__('vision') + llm_model and 'vision' in (llm_model.model_entity.abilities or []) ): if msg.base64 is not None: content_list.append(provider_message.ContentElement.from_image_base64(msg.base64)) diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index bcec0683..0c577f1a 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -37,11 +37,41 @@ class ModelManager: self.requester_components = [] self.requester_dict = {} + @staticmethod + def _get_litellm_provider_from_manifest(component: engine.Component | None) -> str | None: + if component is None: + return None + + spec = getattr(component, 'spec', None) or {} + litellm_provider = None + + if isinstance(spec, dict): + litellm_provider = spec.get('litellm_provider') + else: + getter = getattr(spec, 'get', None) + if callable(getter): + try: + litellm_provider = getter('litellm_provider') + except Exception: + litellm_provider = None + + if isinstance(litellm_provider, str) and litellm_provider: + return litellm_provider + return None + async def initialize(self): self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: + # Skip components that use litellm_provider (they will use litellmchat.py instead) + litellm_provider = self._get_litellm_provider_from_manifest(component) + if litellm_provider: + self.ap.logger.debug( + f'Skipping Python class loading for {component.metadata.name} ' + f'(uses litellm_provider={litellm_provider})' + ) + continue requester_dict[component.metadata.name] = component.get_python_component_class() self.requester_dict = requester_dict @@ -236,6 +266,7 @@ class ModelManager: name=model_info.get('name', ''), provider_uuid='', abilities=model_info.get('abilities', []), + context_length=model_info.get('context_length'), extra_args=model_info.get('extra_args', {}), ), provider=runtime_provider, @@ -294,13 +325,37 @@ class ModelManager: else: provider_entity = provider_info - if provider_entity.requester not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(provider_entity.requester) + # Get requester manifest to check for litellm_provider + requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester) + litellm_provider = self._get_litellm_provider_from_manifest(requester_manifest) + + # Build config from base_url + config = {'base_url': provider_entity.base_url} + + # Check if requester manifest specifies litellm_provider + if litellm_provider: + from .requesters import litellmchat + + # Use unified LiteLLMRequester with provider prefix + # Map litellm_provider (YAML spec) to custom_llm_provider (config) + config['custom_llm_provider'] = litellm_provider + requester_inst = litellmchat.LiteLLMRequester( + ap=self.ap, + config=config, + ) + self.ap.logger.debug( + f'Using LiteLLMRequester for {provider_entity.requester} ' + f'with custom_llm_provider={config["custom_llm_provider"]}' + ) + else: + # Use original requester class (for backward compatibility) + if provider_entity.requester not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(provider_entity.requester) + requester_inst = self.requester_dict[provider_entity.requester]( + ap=self.ap, + config=config, + ) - requester_inst = self.requester_dict[provider_entity.requester]( - ap=self.ap, - config={'base_url': provider_entity.base_url}, - ) await requester_inst.initialize() token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or []) @@ -406,6 +461,7 @@ class ModelManager: name=model_info.get('name', ''), provider_uuid=model_info.get('provider_uuid', ''), abilities=model_info.get('abilities', []), + context_length=model_info.get('context_length'), extra_args=model_info.get('extra_args', {}), ) diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index cb9a4183..b673c758 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -67,8 +67,8 @@ class RuntimeProvider: if isinstance(result, tuple): msg, usage_info = result if usage_info: - input_tokens = usage_info.get('input_tokens', 0) - output_tokens = usage_info.get('output_tokens', 0) + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) return msg else: return result @@ -128,7 +128,6 @@ class RuntimeProvider: start_time = time.time() status = 'success' error_message = None - # Note: Stream doesn't easily provide token counts, set to 0 input_tokens = 0 output_tokens = 0 @@ -143,6 +142,15 @@ class RuntimeProvider: remove_think=remove_think, ): yield chunk + # Extract usage from stream if available (stored by LiteLLM requester) + if query: + if query.variables is None: + query.variables = {} + if '_stream_usage' in query.variables: + usage_info = query.variables['_stream_usage'] + input_tokens = usage_info.get('prompt_tokens', 0) + output_tokens = usage_info.get('completion_tokens', 0) + del query.variables['_stream_usage'] except Exception as e: status = 'error' error_message = str(e) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.py deleted file mode 100644 index 40a41718..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class AI302ChatCompletions(chatcmpl.OpenAIChatCompletions): - """302.AI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.302.ai/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml index e4f70cae..3cfec198 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 302.AI icon: 302ai.png spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "302ai 302.AI 302 ai 中转 中转站 aggregator gpt claude gemini" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py deleted file mode 100644 index 1428dc88..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ /dev/null @@ -1,370 +0,0 @@ -from __future__ import annotations - -import typing -import json -import platform -import socket -import anthropic -import httpx - -from .. import errors, requester - -from ....utils import image -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class AnthropicMessages(requester.ProviderAPIRequester): - """Anthropic Messages API 请求器""" - - client: anthropic.AsyncAnthropic - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.anthropic.com', - 'timeout': 120, - } - - async def initialize(self): - # 兼容 Windows 缺失 TCP_KEEPINTVL 和 TCP_KEEPCNT 的问题 - if platform.system() == 'Windows': - if not hasattr(socket, 'TCP_KEEPINTVL'): - socket.TCP_KEEPINTVL = 0 - if not hasattr(socket, 'TCP_KEEPCNT'): - socket.TCP_KEEPCNT = 0 - httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( - base_url=self.requester_cfg['base_url'], - # cast to a valid type because mypy doesn't understand our type narrowing - timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']), - limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS, - follow_redirects=True, - trust_env=True, - ) - - self.client = anthropic.AsyncAnthropic( - api_key='', - http_client=httpx_client, - base_url=self.requester_cfg['base_url'], - ) - - async def invoke_llm( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message: - self.client.api_key = model.provider.token_mgr.get_token() - - args = extra_args.copy() - args['model'] = model.model_entity.name - - # 处理消息 - - # system - system_role_message = None - - for i, m in enumerate(messages): - if m.role == 'system': - system_role_message = m - - break - - if system_role_message: - messages.pop(i) - - if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str): - args['system'] = system_role_message.content - - req_messages = [] - - for m in messages: - if m.role == 'tool': - tool_call_id = m.tool_call_id - - req_messages.append( - { - 'role': 'user', - 'content': [ - { - 'type': 'tool_result', - 'tool_use_id': tool_call_id, - 'is_error': False, - 'content': [{'type': 'text', 'text': m.content}], - } - ], - } - ) - - continue - - msg_dict = m.dict(exclude_none=True) - - if isinstance(m.content, str) and m.content.strip() != '': - msg_dict['content'] = [{'type': 'text', 'text': m.content}] - elif isinstance(m.content, list): - for i, ce in enumerate(m.content): - if ce.type == 'image_base64': - image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) - - alter_image_ele = { - 'type': 'image', - 'source': { - 'type': 'base64', - 'media_type': f'image/{image_format}', - 'data': image_b64, - }, - } - msg_dict['content'][i] = alter_image_ele - - if m.tool_calls: - for tool_call in m.tool_calls: - msg_dict['content'].append( - { - 'type': 'tool_use', - 'id': tool_call.id, - 'name': tool_call.function.name, - 'input': json.loads(tool_call.function.arguments), - } - ) - - del msg_dict['tool_calls'] - - req_messages.append(msg_dict) - - args['messages'] = req_messages - - if 'thinking' in args: - args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000} - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs) - - if tools: - args['tools'] = tools - - try: - resp = await self.client.messages.create(**args) - - args = { - 'content': '', - 'role': resp.role, - } - assert type(resp) is anthropic.types.message.Message - - for block in resp.content: - if not remove_think and block.type == 'thinking': - args['content'] = '\n' + block.thinking + '\n\n' + args['content'] - elif block.type == 'text': - args['content'] += block.text - elif block.type == 'tool_use': - assert type(block) is anthropic.types.tool_use_block.ToolUseBlock - tool_call = provider_message.ToolCall( - id=block.id, - type='function', - function=provider_message.FunctionCall(name=block.name, arguments=json.dumps(block.input)), - ) - if 'tool_calls' not in args: - args['tool_calls'] = [] - args['tool_calls'].append(tool_call) - - return provider_message.Message(**args) - except anthropic.AuthenticationError as e: - raise errors.RequesterError(f'api-key 无效: {e.message}') - except anthropic.BadRequestError as e: - raise errors.RequesterError(str(e.message)) - except anthropic.NotFoundError as e: - if 'model: ' in str(e): - raise errors.RequesterError(f'模型无效: {e.message}') - else: - raise errors.RequesterError(f'请求地址无效: {e.message}') - - async def invoke_llm_stream( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message: - self.client.api_key = model.provider.token_mgr.get_token() - - args = extra_args.copy() - args['model'] = model.model_entity.name - args['stream'] = True - - # 处理消息 - - # system - system_role_message = None - - for i, m in enumerate(messages): - if m.role == 'system': - system_role_message = m - - break - - if system_role_message: - messages.pop(i) - - if isinstance(system_role_message, provider_message.Message) and isinstance(system_role_message.content, str): - args['system'] = system_role_message.content - - req_messages = [] - - for m in messages: - if m.role == 'tool': - tool_call_id = m.tool_call_id - - req_messages.append( - { - 'role': 'user', - 'content': [ - { - 'type': 'tool_result', - 'tool_use_id': tool_call_id, - 'is_error': False, # 暂时直接写false - 'content': [ - {'type': 'text', 'text': m.content} - ], # 这里要是list包裹,应该是多个返回的情况?type类型好像也可以填其他的,暂时只写text - } - ], - } - ) - - continue - - msg_dict = m.dict(exclude_none=True) - - if isinstance(m.content, str) and m.content.strip() != '': - msg_dict['content'] = [{'type': 'text', 'text': m.content}] - elif isinstance(m.content, list): - for i, ce in enumerate(m.content): - if ce.type == 'image_base64': - image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) - - alter_image_ele = { - 'type': 'image', - 'source': { - 'type': 'base64', - 'media_type': f'image/{image_format}', - 'data': image_b64, - }, - } - msg_dict['content'][i] = alter_image_ele - if isinstance(msg_dict['content'], str) and msg_dict['content'] == '': - msg_dict['content'] = [] # 这里不知道为什么会莫名有个空导致content为字符 - if m.tool_calls: - for tool_call in m.tool_calls: - msg_dict['content'].append( - { - 'type': 'tool_use', - 'id': tool_call.id, - 'name': tool_call.function.name, - 'input': json.loads(tool_call.function.arguments), - } - ) - - del msg_dict['tool_calls'] - - req_messages.append(msg_dict) - if 'thinking' in args: - args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000} - - args['messages'] = req_messages - - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs) - - if tools: - args['tools'] = tools - - try: - role = 'assistant' # 默认角色 - # chunk_idx = 0 - think_started = False - think_ended = False - finish_reason = False - tool_name = '' - tool_id = '' - async for chunk in await self.client.messages.create(**args): - content = '' - tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'} - if isinstance( - chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent - ): # 记录开始 - if chunk.content_block.type == 'tool_use': - if chunk.content_block.name is not None: - tool_name = chunk.content_block.name - if chunk.content_block.id is not None: - tool_id = chunk.content_block.id - - tool_call['function']['name'] = tool_name - tool_call['function']['arguments'] = '' - tool_call['id'] = tool_id - - if not remove_think: - if chunk.content_block.type == 'thinking' and not remove_think: - think_started = True - elif chunk.content_block.type == 'text' and chunk.index != 0 and not remove_think: - think_ended = True - continue - elif isinstance(chunk, anthropic.types.raw_content_block_delta_event.RawContentBlockDeltaEvent): - if chunk.delta.type == 'thinking_delta': - if think_started: - think_started = False - content = '\n' + chunk.delta.thinking - elif remove_think: - continue - else: - content = chunk.delta.thinking - elif chunk.delta.type == 'text_delta': - if think_ended: - think_ended = False - content = '\n\n' + chunk.delta.text - else: - content = chunk.delta.text - elif chunk.delta.type == 'input_json_delta': - tool_call['function']['arguments'] = chunk.delta.partial_json - tool_call['function']['name'] = tool_name - tool_call['id'] = tool_id - elif isinstance(chunk, anthropic.types.raw_content_block_stop_event.RawContentBlockStopEvent): - continue # 记录raw_content_block结束的 - - elif isinstance(chunk, anthropic.types.raw_message_delta_event.RawMessageDeltaEvent): - if chunk.delta.stop_reason == 'end_turn': - finish_reason = True - elif isinstance(chunk, anthropic.types.raw_message_stop_event.RawMessageStopEvent): - continue # 这个好像是完全结束 - else: - # print(chunk) - self.ap.logger.debug(f'anthropic chunk: {chunk}') - continue - - args = { - 'content': content, - 'role': role, - 'is_final': finish_reason, - 'tool_calls': None if tool_call['id'] is None else [tool_call], - } - # if chunk_idx == 0: - # chunk_idx += 1 - # continue - - # assert type(chunk) is anthropic.types.message.Chunk - - yield provider_message.MessageChunk(**args) - - # return llm_entities.Message(**args) - except anthropic.AuthenticationError as e: - raise errors.RequesterError(f'api-key 无效: {e.message}') - except anthropic.BadRequestError as e: - raise errors.RequesterError(str(e.message)) - except anthropic.NotFoundError as e: - if 'model: ' in str(e): - raise errors.RequesterError(f'模型无效: {e.message}') - else: - raise errors.RequesterError(f'请求地址无效: {e.message}') diff --git a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index 0ef60d3e..8600f85a 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Anthropic icon: anthropic.svg spec: + litellm_provider: anthropic config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "anthropic Anthropic 克劳德 claude Claude Opus Sonnet Haiku 安thropic" support_type: - llm provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/baidu.svg b/src/langbot/pkg/provider/modelmgr/requesters/baidu.svg new file mode 100644 index 00000000..a541c95e --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/baidu.svg @@ -0,0 +1,5 @@ + + + Baidu + ERNIE + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/baiduchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/baiduchatcmpl.yaml new file mode 100644 index 00000000..33af36d5 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/baiduchatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: baidu-chat-completions + label: + en_US: Baidu ERNIE + zh_Hans: 百度文心一言 + icon: baidu.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "baidu Baidu 百度 千帆 qianfan wenxin 文心 文心一言 ernie ERNIE bce embedding bce-reranker" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py deleted file mode 100644 index 9da6e1b4..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -import typing -import dashscope -import openai - -from . import modelscopechatcmpl -from .. import requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions): - """阿里云百炼大模型平台 ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', - 'timeout': 120, - } - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - is_use_dashscope_call = False # 是否使用阿里原生库调用 - is_enable_multi_model = True # 是否支持多轮对话 - use_time_num = 0 # 模型已调用次数,防止存在多文件时重复调用 - use_time_ids = [] # 已调用的ID列表 - message_id = 0 # 记录消息序号 - - for msg in messages: - # print(msg) - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - elif me['type'] == 'file_url' and '.' in me.get('file_name', ''): - # 1. 视频文件推理 - # https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2845871 - file_type = me.get('file_name').lower().split('.')[-1] - if file_type in ['mp4', 'avi', 'mkv', 'mov', 'flv', 'wmv']: - me['type'] = 'video_url' - me['video_url'] = {'url': me['file_url']} - del me['file_url'] - del me['file_name'] - use_time_num += 1 - use_time_ids.append(message_id) - is_enable_multi_model = False - # 2. 语音文件识别, 无法通过openai的audio字段传递,暂时不支持 - # https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2979031 - elif file_type in [ - 'aac', - 'amr', - 'aiff', - 'flac', - 'm4a', - 'mp3', - 'mpeg', - 'ogg', - 'opus', - 'wav', - 'webm', - 'wma', - ]: - me['audio'] = me['file_url'] - me['type'] = 'audio' - del me['file_url'] - del me['type'] - del me['file_name'] - is_use_dashscope_call = True - use_time_num += 1 - use_time_ids.append(message_id) - is_enable_multi_model = False - message_id += 1 - - # 使用列表推导式,保留不在 use_time_ids[:-1] 中的元素,仅保留最后一个多媒体消息 - if not is_enable_multi_model and use_time_num > 1: - messages = [msg for idx, msg in enumerate(messages) if idx not in use_time_ids[:-1]] - - if not is_enable_multi_model: - messages = [msg for msg in messages if 'resp_message_id' not in msg] - - args['messages'] = messages - args['stream'] = True - - # 流式处理状态 - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - - if is_use_dashscope_call: - response = dashscope.MultiModalConversation.call( - # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx" - api_key=use_model.provider.token_mgr.get_token(), - model=use_model.model_entity.name, - messages=messages, - result_format='message', - asr_options={ - # "language": "zh", # 可选,若已知音频的语种,可通过该参数指定待识别语种,以提升识别准确率 - 'enable_lid': True, - 'enable_itn': False, - }, - stream=True, - ) - content_length_list = [] - previous_length = 0 # 记录上一次的内容长度 - for res in response: - chunk = res['output'] - # 解析 chunk 数据 - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta_content = choice['message'].content[0]['text'] - finish_reason = choice['finish_reason'] - content_length_list.append(len(delta_content)) - else: - delta_content = '' - finish_reason = None - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content: - chunk_idx += 1 - continue - - # 检查 content_length_list 是否有足够的数据 - if len(content_length_list) >= 2: - now_content = delta_content[previous_length : content_length_list[-1]] - previous_length = content_length_list[-1] # 更新上一次的长度 - else: - now_content = delta_content # 第一次循环时直接使用 delta_content - previous_length = len(delta_content) # 更新上一次的长度 - - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': now_content if now_content else None, - 'is_final': bool(finish_reason) and finish_reason != 'null', - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 - else: - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - reasoning_content = delta.get('reasoning_content', '') - - # 处理 reasoning_content - if reasoning_content: - # accumulated_reasoning += reasoning_content - # 如果设置了 remove_think,跳过 reasoning_content - if remove_think: - chunk_idx += 1 - continue - - # 第一次出现 reasoning_content,添加 开始标签 - if not thinking_started: - thinking_started = True - delta_content = '\n' + reasoning_content - else: - # 继续输出 reasoning_content - delta_content = reasoning_content - elif thinking_started and not thinking_ended and delta_content: - # reasoning_content 结束,normal content 开始,添加 结束标签 - thinking_ended = True - delta_content = '\n\n' + delta_content - - # 处理工具调用增量 - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] != '': - tool_id = tool_call['id'] - if tool_call['function']['name'] is not None: - tool_name = tool_call['function']['name'] - - if tool_call['type'] is None: - tool_call['type'] = 'function' - tool_call['id'] = tool_id - tool_call['function']['name'] = tool_name - tool_call['function']['arguments'] = ( - '' if tool_call['function']['arguments'] is None else tool_call['function']['arguments'] - ) - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 - # return diff --git a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index fc5998c4..75b97b7f 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 阿里云百炼 icon: bailian.png spec: + litellm_provider: openai config: - name: base_url label: @@ -22,8 +23,10 @@ spec: type: integer required: true default: 120 + alias: "bailian 百炼 阿里 阿里云 aliyun alibaba dashscope 通义 通义千问 qwen Qwen tongyi gte-rerank text-embedding-v" support_type: - llm + - text-embedding - rerank provider_category: maas execution: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py deleted file mode 100644 index e63e362b..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py +++ /dev/null @@ -1,702 +0,0 @@ -from __future__ import annotations - -import asyncio -import typing - -import openai -import openai.types.chat.chat_completion as chat_completion_module -import httpx - -from .. import errors, requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class OpenAIChatCompletions(requester.ProviderAPIRequester): - """OpenAI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.openai.com/v1', - 'timeout': 120, - } - - async def initialize(self): - self.client = openai.AsyncClient( - api_key=self.init_api_key, - base_url=self.requester_cfg['base_url'].replace(' ', ''), - timeout=self.requester_cfg['timeout'], - http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), - ) - - def _mask_api_key(self, api_key: str | None) -> str: - if not api_key: - return '' - if len(api_key) <= 8: - return '****' - return f'{api_key[:4]}...{api_key[-4:]}' - - def _infer_model_type(self, model_id: str) -> str: - normalized_model_id = (model_id or '').lower() - embedding_keywords = ( - 'embedding', - 'embed', - 'bge-', - 'e5-', - 'm3e', - 'gte-', - 'multilingual-e5', - 'text-embedding', - ) - return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' - - def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: - normalized_model_id = (model_id or '').lower() - abilities: set[str] = set() - - def _flatten(value: typing.Any) -> list[str]: - if value is None: - return [] - if isinstance(value, str): - return [value.lower()] - if isinstance(value, dict): - flattened: list[str] = [] - for nested_value in value.values(): - flattened.extend(_flatten(nested_value)) - return flattened - if isinstance(value, (list, tuple, set)): - flattened: list[str] = [] - for nested_value in value: - flattened.extend(_flatten(nested_value)) - return flattened - return [str(value).lower()] - - capability_tokens = _flatten(item.get('capabilities')) - capability_tokens.extend(_flatten(item.get('modalities'))) - capability_tokens.extend(_flatten(item.get('input_modalities'))) - capability_tokens.extend(_flatten(item.get('output_modalities'))) - capability_tokens.extend(_flatten(item.get('supported_generation_methods'))) - capability_tokens.extend(_flatten(item.get('supported_parameters'))) - capability_tokens.extend(_flatten(item.get('architecture'))) - - combined_tokens = capability_tokens + [normalized_model_id] - - vision_keywords = ( - 'vision', - 'image', - 'file', - 'video', - 'multimodal', - 'vl', - 'ocr', - 'omni', - ) - function_call_keywords = ( - 'function', - 'tool', - 'tools', - 'tool_choice', - 'tool_call', - 'tool-use', - 'tool_use', - ) - - if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens): - abilities.add('vision') - - if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens): - abilities.add('func_call') - - return sorted(abilities) - - def _normalize_modalities(self, value: typing.Any) -> list[str]: - normalized: list[str] = [] - - def _collect(item: typing.Any): - if item is None: - return - if isinstance(item, str): - for part in item.replace('->', ',').replace('+', ',').split(','): - token = part.strip().lower() - if token and token not in normalized: - normalized.append(token) - return - if isinstance(item, dict): - for nested in item.values(): - _collect(nested) - return - if isinstance(item, (list, tuple, set)): - for nested in item: - _collect(nested) - return - - _collect(value) - return normalized - - def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]: - display_name = item.get('name') - if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id: - display_name = '' - - description = item.get('description') - if not isinstance(description, str) or not description.strip(): - description = '' - - context_length = item.get('context_length') - if context_length is None and isinstance(item.get('top_provider'), dict): - context_length = item['top_provider'].get('context_length') - - if not isinstance(context_length, int): - try: - context_length = int(context_length) if context_length is not None else None - except (TypeError, ValueError): - context_length = None - - input_modalities = self._normalize_modalities(item.get('input_modalities')) - output_modalities = self._normalize_modalities(item.get('output_modalities')) - - if isinstance(item.get('architecture'), dict): - if not input_modalities: - input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities')) - if not output_modalities: - output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities')) - - owned_by = item.get('owned_by') - if not isinstance(owned_by, str) or not owned_by.strip(): - owned_by = '' - - return { - 'display_name': display_name or None, - 'description': description or None, - 'context_length': context_length, - 'owned_by': owned_by or None, - 'input_modalities': input_modalities, - 'output_modalities': output_modalities, - } - - async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: - headers = {} - if api_key: - headers['Authorization'] = f'Bearer {api_key}' - - models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models' - async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: - response = await client.get(models_url, headers=headers) - response.raise_for_status() - payload = response.json() - - models = [] - for item in payload.get('data', []): - model_id = item.get('id') - if not model_id: - continue - models.append( - { - 'id': model_id, - 'name': model_id, - 'type': self._infer_model_type(model_id), - 'abilities': self._infer_model_abilities(item, model_id), - **self._extract_scan_metadata(item, model_id), - } - ) - - models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) - return { - 'models': models, - 'debug': { - 'request': { - 'method': 'GET', - 'url': models_url, - 'headers': { - 'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '', - }, - }, - 'response': payload, - }, - } - - async def _req( - self, - args: dict, - extra_body: dict = {}, - ) -> chat_completion_module.ChatCompletion: - return await self.client.chat.completions.create(**args, extra_body=extra_body) - - async def _req_stream( - self, - args: dict, - extra_body: dict = {}, - ): - async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): - yield chunk - - async def _make_msg( - self, - chat_completion: chat_completion_module.ChatCompletion, - remove_think: bool = False, - ) -> provider_message.Message: - if not isinstance(chat_completion, chat_completion_module.ChatCompletion): - raise TypeError(f'Expected ChatCompletion, got {type(chat_completion).__name__}: {chat_completion[:16]}') - - chatcmpl_message = chat_completion.choices[0].message.model_dump() - - # 确保 role 字段存在且不为 None - if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: - chatcmpl_message['role'] = 'assistant' - - # 处理思维链 - content = chatcmpl_message.get('content', '') - reasoning_content = chatcmpl_message.get('reasoning_content', None) - - processed_content, _ = await self._process_thinking_content( - content=content, reasoning_content=reasoning_content, remove_think=remove_think - ) - - chatcmpl_message['content'] = processed_content - - # 移除 reasoning_content 字段,避免传递给 Message - if 'reasoning_content' in chatcmpl_message: - del chatcmpl_message['reasoning_content'] - - message = provider_message.Message(**chatcmpl_message) - - return message - - async def _process_thinking_content( - self, - content: str, - reasoning_content: str = None, - remove_think: bool = False, - ) -> tuple[str, str]: - """处理思维链内容 - - Args: - content: 原始内容 - reasoning_content: reasoning_content 字段内容 - remove_think: 是否移除思维链 - - Returns: - (处理后的内容, 提取的思维链内容) - """ - thinking_content = '' - - # 1. 从 reasoning_content 提取思维链 - if reasoning_content: - thinking_content = reasoning_content - - # 2. 从 content 中提取 标签内容 - if content and '' in content and '' in content: - import re - - think_pattern = r'(.*?)' - think_matches = re.findall(think_pattern, content, re.DOTALL) - if think_matches: - # 如果已有 reasoning_content,则追加 - if thinking_content: - thinking_content += '\n' + '\n'.join(think_matches) - else: - thinking_content = '\n'.join(think_matches) - # 移除 content 中的 标签 - content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip() - - # 3. 根据 remove_think 参数决定是否保留思维链 - if remove_think: - return content, '' - else: - # 如果有思维链内容,将其以 格式添加到 content 开头 - if thinking_content: - content = f'\n{thinking_content}\n\n{content}'.strip() - return content, thinking_content - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.MessageChunk: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - args['stream'] = True - - # 流式处理状态 - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - tool_id = '' - tool_name = '' - # accumulated_reasoning = '' # 仅用于判断何时结束思维链 - - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - reasoning_content = delta.get('reasoning_content', '') - - # 处理 reasoning_content - if reasoning_content: - # accumulated_reasoning += reasoning_content - # 如果设置了 remove_think,跳过 reasoning_content - if remove_think: - chunk_idx += 1 - continue - - # 第一次出现 reasoning_content,添加 开始标签 - if not thinking_started: - thinking_started = True - delta_content = '\n' + reasoning_content - else: - # 继续输出 reasoning_content - delta_content = reasoning_content - elif thinking_started and not thinking_ended and delta_content: - # reasoning_content 结束,normal content 开始,添加 结束标签 - thinking_ended = True - delta_content = '\n\n' + delta_content - - # 处理 content 中已有的 标签(如果需要移除) - # if delta_content and remove_think and '' in delta_content: - # import re - # - # # 移除 标签及其内容 - # delta_content = re.sub(r'.*?', '', delta_content, flags=re.DOTALL) - - # 处理工具调用增量 - # delta_tool_calls = None - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] and tool_call['function']['name']: - tool_id = tool_call['id'] - tool_name = tool_call['function']['name'] - else: - tool_call['id'] = tool_id - tool_call['function']['name'] = tool_name - if tool_call['type'] is None: - tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 - - async def _closure( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> tuple[provider_message.Message, dict]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - - # 发送请求 - - resp = await self._req(args, extra_body=extra_args) - # 处理请求结果 - message = await self._make_msg(resp, remove_think) - - # Extract token usage from response - usage_info = {} - if hasattr(resp, 'usage') and resp.usage: - usage_info['input_tokens'] = resp.usage.prompt_tokens or 0 - usage_info['output_tokens'] = resp.usage.completion_tokens or 0 - usage_info['total_tokens'] = resp.usage.total_tokens or 0 - - return message, usage_info - - async def invoke_llm( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> tuple[provider_message.Message, dict]: - """Invoke LLM and return message with usage info""" - req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 - for m in messages: - msg_dict = m.dict(exclude_none=True) - content = msg_dict.get('content') - if isinstance(content, list): - # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - # 将所有文本部分合并为一个字符串 - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) - - try: - msg, usage_info = await self._closure( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - remove_think=remove_think, - ) - return msg, usage_info - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - except openai.BadRequestError as e: - error_message = str(e.message) if hasattr(e, 'message') else str(e) - if 'context_length_exceeded' in str(e): - raise errors.RequesterError(f'上文过长,请重置会话: {error_message}') - else: - raise errors.RequesterError(f'请求参数错误: {error_message}') - except openai.AuthenticationError as e: - error_message = str(e.message) if hasattr(e, 'message') else str(e) - raise errors.RequesterError(f'无效的 api-key: {error_message}') - except openai.NotFoundError as e: - error_message = str(e.message) if hasattr(e, 'message') else str(e) - raise errors.RequesterError(f'请求路径错误: {error_message}') - except openai.RateLimitError as e: - error_message = str(e.message) if hasattr(e, 'message') else str(e) - raise errors.RequesterError(f'请求过于频繁或余额不足: {error_message}') - except openai.APIConnectionError as e: - error_message = f'连接错误: {str(e)}' - raise errors.RequesterError(error_message) - except openai.APIError as e: - error_message = str(e.message) if hasattr(e, 'message') else str(e) - raise errors.RequesterError(f'请求错误: {error_message}') - - async def invoke_embedding( - self, - model: requester.RuntimeEmbeddingModel, - input_text: list[str], - extra_args: dict[str, typing.Any] = {}, - ) -> tuple[list[list[float]], dict]: - """调用 Embedding API, returns (embeddings, usage_info)""" - self.client.api_key = model.provider.token_mgr.get_token() - - args = { - 'model': model.model_entity.name, - 'input': input_text, - } - - if model.model_entity.extra_args: - args.update(model.model_entity.extra_args) - - args.update(extra_args) - - try: - resp = await self.client.embeddings.create(**args) - - # Extract usage info - usage_info = {} - if hasattr(resp, 'usage') and resp.usage: - usage_info['prompt_tokens'] = resp.usage.prompt_tokens or 0 - usage_info['total_tokens'] = resp.usage.total_tokens or 0 - - return [d.embedding for d in resp.data], usage_info - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - except openai.BadRequestError as e: - raise errors.RequesterError(f'请求参数错误: {e.message}') - - async def invoke_llm_stream( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.MessageChunk: - req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 - for m in messages: - msg_dict = m.dict(exclude_none=True) - content = msg_dict.get('content') - if isinstance(content, list): - # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - # 将所有文本部分合并为一个字符串 - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) - - try: - async for item in self._closure_stream( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - remove_think=remove_think, - ): - yield item - - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - except openai.BadRequestError as e: - if 'context_length_exceeded' in e.message: - raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') - else: - raise errors.RequesterError(f'请求参数错误: {e.message}') - except openai.AuthenticationError as e: - raise errors.RequesterError(f'无效的 api-key: {e.message}') - except openai.NotFoundError as e: - raise errors.RequesterError(f'请求路径错误: {e.message}') - except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') - except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') - - async def invoke_rerank( - self, - model: requester.RuntimeRerankModel, - query: str, - documents: typing.List[str], - extra_args: dict[str, typing.Any] = {}, - ) -> typing.List[dict]: - """Standard /rerank endpoint (Jina/Cohere/SiliconFlow/Voyage/DashScope compatible) - - Supports extra_args from model.extra_args: - - rerank_url: full URL override (e.g. "https://dashscope.aliyuncs.com/compatible-api/v1/reranks") - - rerank_path: path override appended to base_url (e.g. "reranks" instead of default "rerank") - - Any other fields are merged into the request payload. - """ - api_key = model.provider.token_mgr.get_token() - base_url = self.requester_cfg.get('base_url', '').rstrip('/') - timeout = self.requester_cfg.get('timeout', 120) - - merged_args = {} - if model.model_entity.extra_args: - merged_args.update(model.model_entity.extra_args) - if extra_args: - merged_args.update(extra_args) - - rerank_url = merged_args.pop('rerank_url', None) - rerank_path = merged_args.pop('rerank_path', 'rerank') - if not rerank_url: - rerank_url = f'{base_url}/{rerank_path}' - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {api_key}', - } - - payload = { - 'model': model.model_entity.name, - 'query': query, - 'documents': documents[:64], - 'top_n': min(len(documents), 64), - } - - if merged_args: - payload.update(merged_args) - - try: - async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client: - resp = await client.post(rerank_url, headers=headers, json=payload) - resp.raise_for_status() - data = resp.json() - - results = self._parse_rerank_response(data) - - if results: - scores = [r.get('relevance_score', 0.0) for r in results] - min_score = min(scores) - max_score = max(scores) - if max_score - min_score > 1e-6: - for r in results: - r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score) - - return results - except httpx.HTTPStatusError as e: - raise errors.RequesterError(f'Rerank request failed: {e.response.status_code} - {e.response.text}') - except httpx.TimeoutException: - raise errors.RequesterError('Rerank request timed out') - except Exception as e: - raise errors.RequesterError(f'Rerank request error: {str(e)}') - - @staticmethod - def _parse_rerank_response(data: dict) -> typing.List[dict]: - """Parse rerank response from various providers. - - Handles: - - Jina/Cohere/SiliconFlow: {"results": [{"index", "relevance_score"}]} - - Voyage AI: {"data": [{"index", "relevance_score"}]} - - DashScope: {"output": {"results": [{"index", "relevance_score"}]}} - """ - if 'results' in data: - return data['results'] - if 'data' in data: - return data['data'] - if 'output' in data and isinstance(data['output'], dict): - return data['output'].get('results', []) - return [] diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 21bd6a05..526721b0 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: OpenAI icon: openai.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,10 +23,10 @@ spec: type: integer required: true default: 120 + alias: "openai OpenAI 欧派 gpt GPT ChatGPT chatgpt o1 o3 o4 text-embedding 通用 openai兼容 compatible" support_type: - llm - text-embedding - - rerank provider_category: manufacturer execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chromaembed.yaml b/src/langbot/pkg/provider/modelmgr/requesters/chromaembed.yaml index 396b8c16..51f2f821 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chromaembed.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/chromaembed.yaml @@ -12,6 +12,7 @@ metadata: icon: chroma.svg spec: config: [] + alias: "chroma Chroma 向量 vector embedding 嵌入 chromadb" support_type: - text-embedding provider_category: builtin diff --git a/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml index f1ca209b..e67651bc 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Cohere icon: cohere.svg spec: + litellm_provider: cohere config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "cohere Cohere rerank 重排 reranker rerank-english rerank-multilingual command" support_type: - rerank provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.py deleted file mode 100644 index d272e721..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class CompShareChatCompletions(chatcmpl.OpenAIChatCompletions): - """CompShare ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.modelverse.cn/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml index 92fcafdc..843ac9be 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 优云智算 icon: compshare.png spec: + litellm_provider: openai config: - name: base_url label: @@ -22,8 +23,11 @@ spec: type: integer required: true default: 120 + alias: "compshare 优刻得 ucloud UCloud 算力 共享算力 GPU" support_type: - llm + - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py deleted file mode 100644 index 5bcbd40c..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import typing - -from . import chatcmpl -from .. import errors, requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): - """Deepseek ChatCompletion API 请求器""" - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.deepseek.com', - 'timeout': 120, - } - - async def _closure( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> tuple[provider_message.Message, dict]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages - - # deepseek 不支持多模态,把content都转换成纯文字 - for m in messages: - if 'content' in m and isinstance(m['content'], list): - m['content'] = ' '.join([c['text'] for c in m['content'] if 'text' in c]) - - args['messages'] = messages - - # 发送请求 - resp = await self._req(args, extra_body=extra_args) - - # print(resp) - - if resp is None: - raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常') - # 处理请求结果 - message = await self._make_msg(resp, remove_think) - - # Extract token usage from response - usage_info = {} - if hasattr(resp, 'usage') and resp.usage: - usage_info['input_tokens'] = resp.usage.prompt_tokens or 0 - usage_info['output_tokens'] = resp.usage.completion_tokens or 0 - usage_info['total_tokens'] = resp.usage.total_tokens or 0 - - return message, usage_info diff --git a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index 8ef1fcf9..46604670 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: DeepSeek icon: deepseek.svg spec: + litellm_provider: deepseek config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "deepseek DeepSeek 深度求索 深度 求索 dpsk v3 r1 deepseek-chat deepseek-reasoner" support_type: - llm provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/doubao.svg b/src/langbot/pkg/provider/modelmgr/requesters/doubao.svg new file mode 100644 index 00000000..e47c7232 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/doubao.svg @@ -0,0 +1,4 @@ + + + 豆包 + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/doubaochatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/doubaochatcmpl.yaml new file mode 100644 index 00000000..b6cb72c9 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/doubaochatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: doubao-chat-completions + label: + en_US: ByteDance Doubao + zh_Hans: 字节豆包 + icon: doubao.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://ark.cn-beijing.volces.com/api/v3 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "doubao 豆包 字节 字节跳动 bytedance volcengine 火山 火山引擎 ark 方舟 seed" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py deleted file mode 100644 index 956b49f6..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py +++ /dev/null @@ -1,205 +0,0 @@ -from __future__ import annotations - -import typing -import httpx - -from . import chatcmpl - -import uuid - -from .. import requester -import langbot_plugin.api.entities.builtin.provider.message as provider_message -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool - - -class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): - """Google Gemini API 请求器""" - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://generativelanguage.googleapis.com/v1beta/openai', - 'timeout': 120, - } - - async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: - models_url = 'https://generativelanguage.googleapis.com/v1beta/models' - params = {'key': api_key} if api_key else {} - - all_models: list[dict[str, typing.Any]] = [] - next_page_token = '' - last_payload: dict[str, typing.Any] = {} - - async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: - while True: - request_params = dict(params) - if next_page_token: - request_params['pageToken'] = next_page_token - - response = await client.get(models_url, params=request_params) - response.raise_for_status() - payload = response.json() - last_payload = payload - - for item in payload.get('models', []): - model_name = item.get('name', '') - model_id = model_name.replace('models/', '', 1) - if not model_id: - continue - - supported_methods = item.get('supportedGenerationMethods', []) or [] - if 'embedContent' in supported_methods and 'generateContent' not in supported_methods: - model_type = 'embedding' - else: - model_type = 'llm' - - all_models.append( - { - 'id': model_id, - 'name': model_id, - 'type': model_type, - 'abilities': self._infer_model_abilities(item, model_id), - 'display_name': item.get('displayName') or None, - 'description': item.get('description') or None, - 'context_length': item.get('inputTokenLimit'), - 'input_modalities': self._normalize_modalities(item.get('inputModalities')), - 'output_modalities': self._normalize_modalities(item.get('outputModalities')), - } - ) - - next_page_token = payload.get('nextPageToken', '') - if not next_page_token: - break - - all_models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) - return { - 'models': all_models, - 'debug': { - 'request': { - 'method': 'GET', - 'url': models_url, - 'query': {'key': self._mask_api_key(api_key)} if api_key else {}, - }, - 'response': last_payload, - }, - } - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.MessageChunk: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - args['stream'] = True - - # 流式处理状态 - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - tool_id = '' - tool_name = '' - # accumulated_reasoning = '' # 仅用于判断何时结束思维链 - - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - reasoning_content = delta.get('reasoning_content', '') - - # 处理 reasoning_content - if reasoning_content: - # accumulated_reasoning += reasoning_content - # 如果设置了 remove_think,跳过 reasoning_content - if remove_think: - chunk_idx += 1 - continue - - # 第一次出现 reasoning_content,添加 开始标签 - if not thinking_started: - thinking_started = True - delta_content = '\n' + reasoning_content - else: - # 继续输出 reasoning_content - delta_content = reasoning_content - elif thinking_started and not thinking_ended and delta_content: - # reasoning_content 结束,normal content 开始,添加 结束标签 - thinking_ended = True - delta_content = '\n\n' + delta_content - - # 处理 content 中已有的 标签(如果需要移除) - # if delta_content and remove_think and '' in delta_content: - # import re - # - # # 移除 标签及其内容 - # delta_content = re.sub(r'.*?', '', delta_content, flags=re.DOTALL) - - # 处理工具调用增量 - # delta_tool_calls = None - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] == '' and tool_id == '': - tool_id = str(uuid.uuid4()) - if tool_call['function']['name']: - tool_name = tool_call['function']['name'] - tool_call['id'] = tool_id - tool_call['function']['name'] = tool_name - if tool_call['type'] is None: - tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 diff --git a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index fdebe9b9..68a81a8b 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Google Gemini icon: gemini.svg spec: + litellm_provider: gemini config: - name: base_url label: @@ -22,8 +23,10 @@ spec: type: integer required: true default: 120 + alias: "gemini Gemini 谷歌 google Google 双子座 bard flash pro text-embedding-004" support_type: - llm + - text-embedding provider_category: manufacturer execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py deleted file mode 100644 index 4e295e9f..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - - -import typing - -from . import ppiochatcmpl - - -class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions): - """Gitee AI ChatCompletions API 请求器""" - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://ai.gitee.com/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index b7b158a7..d5b7ef3f 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Gitee AI icon: giteeai.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "gitee Gitee 码云 gitee-ai gitee ai serverless bge embedding rerank" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/groq.svg b/src/langbot/pkg/provider/modelmgr/requesters/groq.svg new file mode 100644 index 00000000..7c84ba68 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/groq.svg @@ -0,0 +1,4 @@ + + + Groq + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/groqchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/groqchatcmpl.yaml new file mode 100644 index 00000000..d5136747 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/groqchatcmpl.yaml @@ -0,0 +1,29 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: groq-chat-completions + label: + en_US: Groq + zh_Hans: Groq + icon: groq.svg +spec: + litellm_provider: groq + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.groq.com/openai/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "groq Groq 高速 llama mixtral 推理加速 lpu" + support_type: + - llm + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/iflytek.svg b/src/langbot/pkg/provider/modelmgr/requesters/iflytek.svg new file mode 100644 index 00000000..7498b149 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/iflytek.svg @@ -0,0 +1,5 @@ + + + iFlytek + Spark + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/iflytekchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/iflytekchatcmpl.yaml new file mode 100644 index 00000000..decc2222 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/iflytekchatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: iflytek-chat-completions + label: + en_US: iFlytek Spark + zh_Hans: 讯飞星火 + icon: iflytek.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://spark-api-open.xf-yun.com/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "iflytek 讯飞 科大讯飞 星火 spark xinghuo xunfei 讯飞星火" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py deleted file mode 100644 index 305ae21f..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py +++ /dev/null @@ -1,208 +0,0 @@ -from __future__ import annotations - -import openai -import typing - -from . import chatcmpl -from .. import requester -import openai.types.chat.chat_completion as chat_completion -import re -import langbot_plugin.api.entities.builtin.provider.message as provider_message -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool - - -class JieKouAIChatCompletions(chatcmpl.OpenAIChatCompletions): - """接口 AI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.jiekou.ai/openai', - 'timeout': 120, - } - - is_think: bool = False - - async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - remove_think: bool, - ) -> provider_message.Message: - chatcmpl_message = chat_completion.choices[0].message.model_dump() - # print(chatcmpl_message.keys(), chatcmpl_message.values()) - - # 确保 role 字段存在且不为 None - if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: - chatcmpl_message['role'] = 'assistant' - - reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None - - # deepseek的reasoner模型 - chatcmpl_message['content'] = await self._process_thinking_content( - chatcmpl_message['content'], reasoning_content, remove_think - ) - - # 移除 reasoning_content 字段,避免传递给 Message - if 'reasoning_content' in chatcmpl_message: - del chatcmpl_message['reasoning_content'] - - message = provider_message.Message(**chatcmpl_message) - - return message - - async def _process_thinking_content( - self, - content: str, - reasoning_content: str = None, - remove_think: bool = False, - ) -> tuple[str, str]: - """处理思维链内容 - - Args: - content: 原始内容 - reasoning_content: reasoning_content 字段内容 - remove_think: 是否移除思维链 - - Returns: - 处理后的内容 - """ - if remove_think: - content = re.sub(r'.*?', '', content, flags=re.DOTALL) - else: - if reasoning_content is not None: - content = '\n' + reasoning_content + '\n\n' + content - return content - - async def _make_msg_chunk( - self, - delta: dict[str, typing.Any], - idx: int, - ) -> provider_message.MessageChunk: - # 处理流式chunk和完整响应的差异 - # print(chat_completion.choices[0]) - - # 确保 role 字段存在且不为 None - if 'role' not in delta or delta['role'] is None: - delta['role'] = 'assistant' - - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None - - delta['content'] = '' if delta['content'] is None else delta['content'] - # print(reasoning_content) - - # deepseek的reasoner模型 - - if reasoning_content is not None: - delta['content'] += reasoning_content - - message = provider_message.MessageChunk(**delta) - - return message - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - args['stream'] = True - - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - # reasoning_content = delta.get('reasoning_content', '') - - if remove_think: - if delta['content'] is not None: - if '' in delta['content'] and not thinking_started and not thinking_ended: - thinking_started = True - continue - elif delta['content'] == r'' and not thinking_ended: - thinking_ended = True - continue - elif thinking_ended and delta['content'] == '\n\n' and thinking_started: - thinking_started = False - continue - elif thinking_started and not thinking_ended: - continue - - # delta_tool_calls = None - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] and tool_call['function']['name']: - tool_id = tool_call['id'] - tool_name = tool_call['function']['name'] - - if tool_call['id'] is None: - tool_call['id'] = tool_id - if tool_call['function']['name'] is None: - tool_call['function']['name'] = tool_name - if tool_call['function']['arguments'] is None: - tool_call['function']['arguments'] = '' - if tool_call['type'] is None: - tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.yaml index 3c791d73..44aa0774 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 接口 AI icon: jiekouai.png spec: + litellm_provider: openai config: - name: base_url label: @@ -29,9 +30,11 @@ spec: type: int required: true default: 120 + alias: "jiekouai 接口AI 接口 jiekou ai 中转 中转站 aggregator" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml index 3b448e38..b94b2f74 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Jina icon: jina.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "jina Jina jina-ai jinaai rerank 重排 reranker jina-reranker embedding" support_type: - rerank provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py new file mode 100644 index 00000000..6b087916 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -0,0 +1,733 @@ +"""LiteLLM unified requester for chat, embedding, and rerank.""" + +from __future__ import annotations + +import typing + +import litellm +from litellm import acompletion, aembedding, arerank + +from .. import errors, requester +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message + + +class LiteLLMRequester(requester.ProviderAPIRequester): + """LiteLLM unified API requester supporting chat, embedding, and rerank.""" + + _EMBEDDING_MODEL_HINTS = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding') + _RERANK_MODEL_HINTS = ('rerank', 're-rank', 're_rank') + + default_config: dict[str, typing.Any] = { + 'base_url': '', + 'timeout': 120, + 'custom_llm_provider': '', + 'drop_params': False, + 'num_retries': 0, + 'api_version': '', + } + + async def initialize(self): + """Initialize LiteLLM client settings.""" + # LiteLLM doesn't require explicit client initialization + # Configuration is passed per-request via litellm params + pass + + def _build_litellm_model_name(self, model_name: str, custom_llm_provider: str | None = None) -> str: + """Build LiteLLM model name with provider prefix if needed.""" + provider = custom_llm_provider or self.requester_cfg.get('custom_llm_provider', '') + if provider: + # LiteLLM format: provider/model_name + if model_name.startswith(f'{provider}/'): + return model_name + return f'{provider}/{model_name}' + # If no custom provider, assume model_name already includes prefix or is OpenAI-compatible + return model_name + + def _get_custom_llm_provider(self) -> str | None: + return self.requester_cfg.get('custom_llm_provider') or None + + def _safe_litellm_bool_helper(self, helper_name: str, model_name: str) -> bool: + """Call a LiteLLM boolean capability helper without letting metadata gaps fail requests.""" + helper = getattr(litellm, helper_name, None) + if not callable(helper): + return False + + provider = self._get_custom_llm_provider() + candidates: list[tuple[str, str | None]] = [(model_name, provider)] + litellm_model_name = self._build_litellm_model_name(model_name) + if litellm_model_name != model_name: + candidates.append((litellm_model_name, None)) + for metadata_provider in self._metadata_provider_candidates(model_name): + candidates.append((f'{metadata_provider}/{model_name}', None)) + + tried_candidates: set[tuple[str, str | None]] = set() + for candidate_model, candidate_provider in candidates: + candidate_key = (candidate_model, candidate_provider) + if candidate_key in tried_candidates: + continue + tried_candidates.add(candidate_key) + try: + if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)): + return True + except Exception: + continue + return False + + def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None: + if not model_payload: + return None + + for field_name in ('context_length', 'context_window', 'max_context_length'): + value = model_payload.get(field_name) + if isinstance(value, bool): + continue + if isinstance(value, int) and value > 0: + return value + if isinstance(value, str) and value.isdigit(): + parsed_value = int(value) + if parsed_value > 0: + return parsed_value + return None + + def _metadata_provider_candidates(self, model_name: str) -> list[str]: + normalized_model_name = (model_name or '').lower() + candidates = [] + if normalized_model_name.startswith(('moonshot-', 'kimi-')): + candidates.append('moonshot') + if normalized_model_name.startswith('deepseek-'): + candidates.append('deepseek') + + base_url = self.requester_cfg.get('base_url', '').lower() + if 'moonshot' in base_url: + candidates.append('moonshot') + if 'deepseek' in base_url: + candidates.append('deepseek') + + deduped_candidates = [] + for candidate in candidates: + if candidate not in deduped_candidates: + deduped_candidates.append(candidate) + return deduped_candidates + + def _known_context_length_fallback(self, model_name: str) -> int | None: + normalized_model_name = (model_name or '').lower() + if normalized_model_name.startswith('deepseek-v4-'): + return 1_000_000 + if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')): + return 256 * 1024 + if normalized_model_name.startswith('moonshot-v1-8k'): + return 8 * 1024 + if normalized_model_name.startswith('moonshot-v1-32k'): + return 32 * 1024 + if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto': + return 128 * 1024 + return None + + def _safe_context_length(self, model_name: str) -> int | None: + helper = getattr(litellm, 'get_max_tokens', None) + if not callable(helper): + return self._known_context_length_fallback(model_name) + + candidates = [model_name] + litellm_model_name = self._build_litellm_model_name(model_name) + if litellm_model_name != model_name: + candidates.append(litellm_model_name) + for provider in self._metadata_provider_candidates(model_name): + candidates.append(f'{provider}/{model_name}') + + tried_candidates = [] + for candidate in candidates: + if candidate in tried_candidates: + continue + tried_candidates.append(candidate) + try: + max_tokens = helper(candidate) + except Exception: + continue + if isinstance(max_tokens, int) and max_tokens > 0: + return max_tokens + return self._known_context_length_fallback(model_name) + + def _supports_function_calling(self, model_name: str) -> bool: + return self._safe_litellm_bool_helper('supports_function_calling', model_name) + + def _supports_vision(self, model_name: str) -> bool: + return self._safe_litellm_bool_helper('supports_vision', model_name) + + def _infer_model_type(self, model_id: str) -> str: + normalized_id = (model_id or '').lower() + if any(kw in normalized_id for kw in self._RERANK_MODEL_HINTS): + return 'rerank' + if any(kw in normalized_id for kw in self._EMBEDDING_MODEL_HINTS): + return 'embedding' + return 'llm' + + def _enrich_scanned_model( + self, + model_id: str, + model_payload: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any]: + model_type = self._infer_model_type(model_id) + scanned_model: dict[str, typing.Any] = { + 'id': model_id, + 'name': model_id, + 'type': model_type, + } + + if model_type == 'llm': + abilities = [] + if self._supports_function_calling(model_id): + abilities.append('func_call') + supports_provider_reported_vision = bool( + model_payload + and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True) + ) + if supports_provider_reported_vision or self._supports_vision(model_id): + abilities.append('vision') + scanned_model['abilities'] = abilities + + context_length = self._context_length_from_scan_payload(model_payload) + if context_length is None: + context_length = self._safe_context_length(model_id) + if context_length is not None: + scanned_model['context_length'] = context_length + + return scanned_model + + def _convert_messages(self, messages: typing.List[provider_message.Message]) -> list[dict]: + """Convert LangBot messages to LiteLLM/OpenAI format.""" + req_messages = [] + for m in messages: + msg_dict = m.dict(exclude_none=True) + content = msg_dict.get('content') + + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get('type') == 'image_base64': + part['image_url'] = {'url': part['image_base64']} + part['type'] = 'image_url' + del part['image_base64'] + + req_messages.append(msg_dict) + + return req_messages + + def _process_thinking_content(self, content: str, reasoning_content: str | None, remove_think: bool) -> str: + """Process thinking/reasoning content. + + Args: + content: The main content from response + reasoning_content: Separate reasoning content from model + remove_think: If True, remove thinking markers; if False, preserve them + + Returns: + Processed content string + """ + # Extract and handle thinking tags + if content and 'CRETIRE_REASONING_BEGINk' in content and 'CRETIRE_REASONING_ENDk' in content: + import re + + think_pattern = r'CRETIRE_REASONING_BEGINk(.*?)CRETIRE_REASONING_ENDk' + + if remove_think: + # Remove thinking tags and their content from output + content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip() + # else: preserve thinking content as-is + + # Handle separate reasoning_content field + # Currently we don't include reasoning_content in user-facing output regardless of remove_think + # because it's typically internal model reasoning, not user-visible thinking + return content or '' + + @staticmethod + def _normalize_usage(usage: typing.Any) -> dict: + """Normalize a LiteLLM/OpenAI usage object into a plain token dict. + + Handles several real-world shapes returned by different upstreams: + - object with ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens`` attrs + - dict with the same keys + - missing ``total_tokens`` (derived from prompt + completion) + - ``None`` / partially-populated usage (defaults to 0) + """ + if usage is None: + return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0} + + def _get(key: str) -> typing.Any: + if isinstance(usage, dict): + return usage.get(key) + return getattr(usage, key, None) + + prompt_tokens = _get('prompt_tokens') or 0 + completion_tokens = _get('completion_tokens') or 0 + total_tokens = _get('total_tokens') or 0 + + # Some providers omit total_tokens in streaming usage; derive it. + if not total_tokens: + total_tokens = prompt_tokens + completion_tokens + + return { + 'prompt_tokens': int(prompt_tokens), + 'completion_tokens': int(completion_tokens), + 'total_tokens': int(total_tokens), + } + + def _extract_usage(self, response) -> dict: + """Extract usage info from a non-streaming LiteLLM response.""" + return self._normalize_usage(getattr(response, 'usage', None)) + + @staticmethod + def _as_dict(value: typing.Any) -> dict: + if value is None: + return {} + if isinstance(value, dict): + return value + if hasattr(value, 'model_dump'): + return value.model_dump() + return {} + + def _normalize_stream_tool_calls( + self, + raw_tool_calls: typing.Any, + tool_call_state: dict[int, dict[str, str]], + ) -> list[dict] | None: + """Fill OpenAI-style streaming tool-call deltas so MessageChunk can validate them.""" + if not raw_tool_calls: + return None + + normalized = [] + for fallback_index, raw_tool_call in enumerate(raw_tool_calls): + tool_call = self._as_dict(raw_tool_call) + index = tool_call.get('index') + if not isinstance(index, int): + index = fallback_index + + state = tool_call_state.setdefault(index, {'id': '', 'type': 'function', 'name': ''}) + if tool_call.get('id'): + state['id'] = tool_call['id'] + if tool_call.get('type'): + state['type'] = tool_call['type'] + + function = self._as_dict(tool_call.get('function')) + if function.get('name'): + state['name'] = function['name'] + + arguments = function.get('arguments') + if arguments is None: + arguments = '' + elif not isinstance(arguments, str): + arguments = str(arguments) + + if not state['id'] or not state['name']: + continue + + normalized.append( + { + 'id': state['id'], + 'type': state['type'] or 'function', + 'function': { + 'name': state['name'], + 'arguments': arguments, + }, + } + ) + + return normalized or None + + def _build_common_args(self, args: dict, include_retry_params: bool = True) -> dict: + """Apply common requester config to args dict.""" + if self.requester_cfg.get('base_url'): + args['api_base'] = self.requester_cfg['base_url'] + if self.requester_cfg.get('timeout'): + args['timeout'] = self.requester_cfg['timeout'] + if include_retry_params: + if self.requester_cfg.get('drop_params'): + args['drop_params'] = self.requester_cfg['drop_params'] + if self.requester_cfg.get('num_retries'): + args['num_retries'] = self.requester_cfg['num_retries'] + if self.requester_cfg.get('api_version'): + args['api_version'] = self.requester_cfg['api_version'] + return args + + def _handle_litellm_error(self, e: Exception) -> None: + """Convert LiteLLM exceptions to RequesterError. Never returns, always raises.""" + # Check more specific exceptions first (they inherit from base exceptions) + if isinstance(e, litellm.ContextWindowExceededError): + raise errors.RequesterError(f'上下文长度超限: {str(e)}') + if isinstance(e, litellm.BadRequestError): + raise errors.RequesterError(f'请求参数错误: {str(e)}') + if isinstance(e, litellm.AuthenticationError): + raise errors.RequesterError(f'API key 无效: {str(e)}') + if isinstance(e, litellm.NotFoundError): + raise errors.RequesterError(f'模型或路径无效: {str(e)}') + if isinstance(e, litellm.RateLimitError): + raise errors.RequesterError(f'请求过于频繁或余额不足: {str(e)}') + if isinstance(e, litellm.Timeout): + raise errors.RequesterError(f'请求超时: {str(e)}') + if isinstance(e, litellm.APIConnectionError): + raise errors.RequesterError(f'连接错误: {str(e)}') + if isinstance(e, litellm.APIError): + raise errors.RequesterError(f'API 错误: {str(e)}') + raise errors.RequesterError(f'未知错误: {str(e)}') + + async def _build_completion_args( + self, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + stream: bool = False, + ) -> dict: + """Build common completion arguments for invoke_llm and invoke_llm_stream.""" + req_messages = self._convert_messages(messages) + model_name = self._build_litellm_model_name(model.model_entity.name) + api_key = model.provider.token_mgr.get_token() + + args = { + 'model': model_name, + 'messages': req_messages, + 'api_key': api_key, + } + if stream: + args['stream'] = True + args['stream_options'] = {'include_usage': True} + self._build_common_args(args) + + # Apply model-level extra_args first, then call-level extra_args + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + args.update(extra_args) + + if funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) + if tools: + args['tools'] = tools + args.setdefault('tool_choice', 'auto') + + return args + + async def invoke_llm( + self, + query: pipeline_query.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> tuple[provider_message.Message, dict]: + """Invoke LLM and return message with usage info.""" + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=False) + + try: + response = await acompletion(**args) + + message_data = response.choices[0].message.model_dump() + if 'role' not in message_data or message_data['role'] is None: + message_data['role'] = 'assistant' + + content = message_data.get('content', '') + reasoning_content = message_data.get('reasoning_content', None) + message_data['content'] = self._process_thinking_content(content, reasoning_content, remove_think) + + if 'reasoning_content' in message_data: + del message_data['reasoning_content'] + + message = provider_message.Message(**message_data) + usage_info = self._extract_usage(response) + + return message, usage_info + + except Exception as e: + self._handle_litellm_error(e) + + async def invoke_llm_stream( + self, + query: pipeline_query.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> provider_message.MessageChunk: + """Invoke LLM streaming and yield chunks.""" + args = await self._build_completion_args(model, messages, funcs, extra_args, stream=True) + + chunk_idx = 0 + role = 'assistant' + tool_call_state: dict[int, dict[str, str]] = {} + + try: + response = await acompletion(**args) + async for chunk in response: + # Capture usage whenever a chunk carries it. + # + # Important: many OpenAI-compatible gateways (e.g. new-api) and + # providers send the final usage payload in a chunk that STILL + # contains a (empty-delta) choice, not an empty `choices` list. + # The previous implementation only captured usage when `choices` + # was empty, so streamed calls always recorded 0 tokens. + # We therefore capture usage independently of `choices`, and then + # fall through to also process any content this chunk may carry. + if getattr(chunk, 'usage', None): + usage_info = self._normalize_usage(chunk.usage) + if query is not None: + if query.variables is None: + query.variables = {} + query.variables['_stream_usage'] = usage_info + + if not hasattr(chunk, 'choices') or not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} + finish_reason = getattr(choice, 'finish_reason', None) + + if 'role' in delta and delta['role']: + role = delta['role'] + + delta_content = delta.get('content', '') + reasoning_content = delta.get('reasoning_content', '') + + # Handle reasoning_content based on remove_think flag + if reasoning_content: + if remove_think: + # Skip reasoning content when remove_think is True + chunk_idx += 1 + continue + else: + # Use reasoning_content as the displayed content + delta_content = reasoning_content + + tool_calls = self._normalize_stream_tool_calls(delta.get('tool_calls'), tool_call_state) + + if chunk_idx == 0 and not delta_content and not tool_calls: + chunk_idx += 1 + continue + + chunk_data = { + 'role': role, + 'content': delta_content if delta_content else None, + 'tool_calls': tool_calls, + 'is_final': bool(finish_reason), + } + + chunk_data = {k: v for k, v in chunk_data.items() if v is not None} + yield provider_message.MessageChunk(**chunk_data) + chunk_idx += 1 + + except Exception as e: + self._handle_litellm_error(e) + + async def invoke_embedding( + self, + model: requester.RuntimeEmbeddingModel, + input_text: list[str], + extra_args: dict[str, typing.Any] = {}, + ) -> tuple[list[list[float]], dict]: + """Invoke embedding and return vectors with usage info.""" + model_name = self._build_litellm_model_name(model.model_entity.name) + api_key = model.provider.token_mgr.get_token() + + args = { + 'model': model_name, + 'input': input_text, + 'api_key': api_key, + } + self._build_common_args(args, include_retry_params=False) + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + response = await aembedding(**args) + + # LiteLLM returns response.data entries either as objects with an + # `.embedding` attribute or as plain dicts (many OpenAI-compatible + # gateways, e.g. new-api, yield dict-shaped entries). Handle both. + embeddings = [d['embedding'] if isinstance(d, dict) else d.embedding for d in response.data] + usage_info = self._extract_usage(response) + + return embeddings, usage_info + + except Exception as e: + self._handle_litellm_error(e) + + async def invoke_rerank( + self, + model: requester.RuntimeRerankModel, + query: str, + documents: typing.List[str], + extra_args: dict[str, typing.Any] = {}, + ) -> typing.List[dict]: + """Invoke rerank and return relevance scores.""" + model_name = self._build_litellm_model_name(model.model_entity.name) + api_key = model.provider.token_mgr.get_token() + + top_n = min(len(documents), 64) + + provider = self._get_custom_llm_provider() + + try: + # LiteLLM's rerank API does not support the `openai` provider + # (litellm/rerank_api/main.py raises "Unsupported provider: openai"). + # OpenAI-compatible gateways (newapi / one-api / vLLM / Xinference, etc.) + # expose the standard Jina/Cohere-style POST /v1/rerank endpoint, so + # call it directly over HTTP for openai-compatible (or unspecified) providers. + if provider in (None, '', 'openai'): + results = await self._invoke_rerank_openai_compatible( + model_name=model.model_entity.name, + query=query, + documents=documents, + api_key=api_key, + top_n=top_n, + extra_args={**(model.model_entity.extra_args or {}), **extra_args}, + ) + else: + args = { + 'model': model_name, + 'query': query, + 'documents': documents, + 'api_key': api_key, + 'top_n': top_n, + } + self._build_common_args(args, include_retry_params=False) + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + response = await arerank(**args) + + results = [] + for r in response.results: + results.append( + { + 'index': r.get('index', 0), + 'relevance_score': r.get('relevance_score', 0.0), + } + ) + + if results: + scores = [r['relevance_score'] for r in results] + min_score = min(scores) + max_score = max(scores) + if max_score - min_score > 1e-6: + for r in results: + r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score) + + return results + + except errors.RequesterError: + raise + except Exception as e: + self._handle_litellm_error(e) + + async def _invoke_rerank_openai_compatible( + self, + model_name: str, + query: str, + documents: typing.List[str], + api_key: str, + top_n: int, + extra_args: dict[str, typing.Any] = {}, + ) -> typing.List[dict]: + """Call the standard Jina/Cohere-style POST /v1/rerank endpoint over HTTP. + + Used for OpenAI-compatible gateways where litellm.arerank rejects the + `openai` provider. Returns the same shape as the litellm path: + a list of {'index': int, 'relevance_score': float}. + """ + import httpx + + base_url = (self.requester_cfg.get('base_url') or '').rstrip('/') + if not base_url: + raise errors.RequesterError('Base URL required for rerank') + + timeout = self.requester_cfg.get('timeout', 120) + + headers = {'Content-Type': 'application/json'} + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + + payload: dict[str, typing.Any] = { + 'model': model_name, + 'query': query, + 'documents': documents, + 'top_n': top_n, + } + if extra_args: + payload.update(extra_args) + + rerank_url = f'{base_url}/rerank' + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post(rerank_url, headers=headers, json=payload) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPStatusError as e: + body = '' + try: + body = e.response.text + except Exception: + pass + raise errors.RequesterError(f'rerank 请求失败 (HTTP {e.response.status_code}): {body or str(e)}') + except httpx.HTTPError as e: + raise errors.RequesterError(f'rerank 连接错误: {str(e)}') + + raw_results = data.get('results', []) if isinstance(data, dict) else [] + results = [] + for r in raw_results: + results.append( + { + 'index': r.get('index', 0), + 'relevance_score': r.get('relevance_score', r.get('score', 0.0)) or 0.0, + } + ) + + return results + + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + """Scan models supported by the provider.""" + import httpx + + base_url = self.requester_cfg.get('base_url', '').rstrip('/') + timeout = self.requester_cfg.get('timeout', 120) + + if not base_url: + raise errors.RequesterError('Base URL required for model scanning') + + headers = {} + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + + models_url = f'{base_url}/models' + + try: + async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client: + response = await client.get(models_url, headers=headers) + response.raise_for_status() + payload = response.json() + + models = [] + for item in payload.get('data', []): + model_id = item.get('id') + if not model_id: + continue + + models.append(self._enrich_scanned_model(model_id, item)) + + models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower())) + + return {'models': models} + + except httpx.HTTPStatusError as e: + raise errors.RequesterError(f'Model scan failed: {e.response.status_code}') + except httpx.TimeoutException: + raise errors.RequesterError('Model scan timeout') + except Exception as e: + raise errors.RequesterError(f'Model scan error: {str(e)}') diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.yaml b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.yaml new file mode 100644 index 00000000..1d5452d5 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.yaml @@ -0,0 +1,65 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: litellm-chat + label: + en_US: LiteLLM (Unified) + zh_Hans: LiteLLM (统一请求器) + icon: litellm.svg +spec: + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: false + default: '' + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + - name: custom_llm_provider + label: + en_US: Custom Provider + zh_Hans: 自定义 Provider + type: string + required: false + default: '' + description: + en_US: Force provider type (e.g., anthropic, openai, gemini) + zh_Hans: 强制指定 provider 类型(如 anthropic, openai, gemini) + - name: drop_params + label: + en_US: Drop Unsupported Params + zh_Hans: 丢弃不支持参数 + type: boolean + required: false + default: false + - name: num_retries + label: + en_US: Number of Retries + zh_Hans: 重试次数 + type: integer + required: false + default: 0 + - name: api_version + label: + en_US: API Version + zh_Hans: API 版本 + type: string + required: false + default: '' + alias: "litellm LiteLLM 通用 universal 万能 兼容 compatible proxy 代理 中转" + support_type: + - llm + - text-embedding + - rerank + provider_category: unified +execution: + python: + path: ./litellmchat.py + attr: LiteLLMRequester \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py deleted file mode 100644 index c9060c1b..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions): - """LMStudio ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'http://127.0.0.1:1234/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 81dc82cf..c1d3ad15 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: LM Studio icon: lmstudio.webp spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "lmstudio LM Studio lm-studio 本地 local 本地部署 self-hosted gguf" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/mimo.svg b/src/langbot/pkg/provider/modelmgr/requesters/mimo.svg new file mode 100644 index 00000000..5d9b21dc --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/mimo.svg @@ -0,0 +1,4 @@ + + + MiMo + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/mimochatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/mimochatcmpl.yaml new file mode 100644 index 00000000..e20f95c8 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/mimochatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: mimo-chat-completions + label: + en_US: Xiaomi MiMo + zh_Hans: 小米 MiMo + icon: mimo.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.xiaomimimo.com/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "mimo MiMo 小米 xiaomi 小米大模型 xiaomi-mimo" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/minimax.svg b/src/langbot/pkg/provider/modelmgr/requesters/minimax.svg new file mode 100644 index 00000000..1afeadc3 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/minimax.svg @@ -0,0 +1,4 @@ + + + MiniMax + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/minimaxchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/minimaxchatcmpl.yaml new file mode 100644 index 00000000..b0c246c9 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/minimaxchatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: minimax-chat-completions + label: + en_US: MiniMax + zh_Hans: MiniMax + icon: minimax.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.minimax.chat/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "minimax MiniMax 名之梦 海螺 hailuo abab embo embedding" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/mistral.svg b/src/langbot/pkg/provider/modelmgr/requesters/mistral.svg new file mode 100644 index 00000000..853022d9 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/mistral.svg @@ -0,0 +1,5 @@ + + + Mistral + AI + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/mistralchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/mistralchatcmpl.yaml new file mode 100644 index 00000000..1ad13686 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/mistralchatcmpl.yaml @@ -0,0 +1,30 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: mistral-chat-completions + label: + en_US: Mistral AI + zh_Hans: Mistral AI + icon: mistral.svg +spec: + litellm_provider: mistral + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.mistral.ai/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "mistral Mistral 米斯特拉尔 mixtral codestral mistral-embed le-chat" + support_type: + - llm + - text-embedding + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py deleted file mode 100644 index c98a71d7..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ /dev/null @@ -1,561 +0,0 @@ -from __future__ import annotations - -import asyncio -import typing - -import openai -import openai.types.chat.chat_completion as chat_completion -import httpx - -from .. import entities, errors, requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class ModelScopeChatCompletions(requester.ProviderAPIRequester): - """ModelScope ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api-inference.modelscope.cn/v1', - 'timeout': 120, - } - - async def initialize(self): - self.client = openai.AsyncClient( - api_key=self.init_api_key, - base_url=self.requester_cfg['base_url'], - timeout=self.requester_cfg['timeout'], - http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), - ) - - def _mask_api_key(self, api_key: str | None) -> str: - if not api_key: - return '' - if len(api_key) <= 8: - return '****' - return f'{api_key[:4]}...{api_key[-4:]}' - - def _infer_model_type(self, model_id: str) -> str: - normalized_model_id = (model_id or '').lower() - embedding_keywords = ( - 'embedding', - 'embed', - 'bge-', - 'e5-', - 'm3e', - 'gte-', - 'multilingual-e5', - 'text-embedding', - ) - return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' - - def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: - normalized_model_id = (model_id or '').lower() - abilities: set[str] = set() - - def _flatten(value: typing.Any) -> list[str]: - if value is None: - return [] - if isinstance(value, str): - return [value.lower()] - if isinstance(value, dict): - flattened: list[str] = [] - for nested_value in value.values(): - flattened.extend(_flatten(nested_value)) - return flattened - if isinstance(value, (list, tuple, set)): - flattened: list[str] = [] - for nested_value in value: - flattened.extend(_flatten(nested_value)) - return flattened - return [str(value).lower()] - - capability_tokens = _flatten(item.get('capabilities')) - capability_tokens.extend(_flatten(item.get('modalities'))) - capability_tokens.extend(_flatten(item.get('input_modalities'))) - capability_tokens.extend(_flatten(item.get('output_modalities'))) - capability_tokens.extend(_flatten(item.get('supported_generation_methods'))) - capability_tokens.extend(_flatten(item.get('supported_parameters'))) - capability_tokens.extend(_flatten(item.get('architecture'))) - - combined_tokens = capability_tokens + [normalized_model_id] - - vision_keywords = ('vision', 'image', 'file', 'video', 'multimodal', 'vl', 'ocr', 'omni') - function_call_keywords = ('function', 'tool', 'tools', 'tool_choice', 'tool_call', 'tool-use', 'tool_use') - - if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens): - abilities.add('vision') - - if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens): - abilities.add('func_call') - - return sorted(abilities) - - def _normalize_modalities(self, value: typing.Any) -> list[str]: - normalized: list[str] = [] - - def _collect(item: typing.Any): - if item is None: - return - if isinstance(item, str): - for part in item.replace('->', ',').replace('+', ',').split(','): - token = part.strip().lower() - if token and token not in normalized: - normalized.append(token) - return - if isinstance(item, dict): - for nested in item.values(): - _collect(nested) - return - if isinstance(item, (list, tuple, set)): - for nested in item: - _collect(nested) - return - - _collect(value) - return normalized - - def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]: - display_name = item.get('name') - if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id: - display_name = '' - - description = item.get('description') - if not isinstance(description, str) or not description.strip(): - description = '' - - context_length = item.get('context_length') - if context_length is None and isinstance(item.get('top_provider'), dict): - context_length = item['top_provider'].get('context_length') - - if not isinstance(context_length, int): - try: - context_length = int(context_length) if context_length is not None else None - except (TypeError, ValueError): - context_length = None - - input_modalities = self._normalize_modalities(item.get('input_modalities')) - output_modalities = self._normalize_modalities(item.get('output_modalities')) - - if isinstance(item.get('architecture'), dict): - if not input_modalities: - input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities')) - if not output_modalities: - output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities')) - - owned_by = item.get('owned_by') - if not isinstance(owned_by, str) or not owned_by.strip(): - owned_by = '' - - return { - 'display_name': display_name or None, - 'description': description or None, - 'context_length': context_length, - 'owned_by': owned_by or None, - 'input_modalities': input_modalities, - 'output_modalities': output_modalities, - } - - async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: - headers = {} - if api_key: - headers['Authorization'] = f'Bearer {api_key}' - - models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models' - async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: - response = await client.get(models_url, headers=headers) - response.raise_for_status() - payload = response.json() - - models = [] - for item in payload.get('data', []): - model_id = item.get('id') - if not model_id: - continue - models.append( - { - 'id': model_id, - 'name': model_id, - 'type': self._infer_model_type(model_id), - 'abilities': self._infer_model_abilities(item, model_id), - **self._extract_scan_metadata(item, model_id), - } - ) - - models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) - return { - 'models': models, - 'debug': { - 'request': { - 'method': 'GET', - 'url': models_url, - 'headers': { - 'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '', - }, - }, - 'response': payload, - }, - } - - async def _req( - self, - query: pipeline_query.Query, - args: dict, - extra_body: dict = {}, - remove_think: bool = False, - ) -> list[dict[str, typing.Any]]: - args['stream'] = True - - chunk = None - - pending_content = '' - - tool_calls = [] - - resp_gen: openai.AsyncStream = await self.client.chat.completions.create(**args, extra_body=extra_body) - - chunk_idx = 0 - thinking_started = False - thinking_ended = False - tool_id = '' - tool_name = '' - message_delta = {} - async for chunk in resp_gen: - if not chunk or not chunk.id or not chunk.choices or not chunk.choices[0] or not chunk.choices[0].delta: - continue - - delta = chunk.choices[0].delta.model_dump() if hasattr(chunk.choices[0], 'delta') else {} - reasoning_content = delta.get('reasoning_content') - # 处理 reasoning_content - if reasoning_content: - # accumulated_reasoning += reasoning_content - # 如果设置了 remove_think,跳过 reasoning_content - if remove_think: - chunk_idx += 1 - continue - - # 第一次出现 reasoning_content,添加 开始标签 - if not thinking_started: - thinking_started = True - pending_content += '\n' + reasoning_content - else: - # 继续输出 reasoning_content - pending_content += reasoning_content - elif thinking_started and not thinking_ended and delta.get('content'): - # reasoning_content 结束,normal content 开始,添加 结束标签 - thinking_ended = True - pending_content += '\n\n' + delta.get('content') - - if delta.get('content') is not None: - pending_content += delta.get('content') - - if delta.get('tool_calls') is not None: - for tool_call in delta.get('tool_calls'): - if tool_call['id'] != '': - tool_id = tool_call['id'] - if tool_call['function']['name'] is not None: - tool_name = tool_call['function']['name'] - if tool_call['function']['arguments'] is None: - continue - tool_call['id'] = tool_id - tool_call['name'] = tool_name - for tc in tool_calls: - if tc['index'] == tool_call['index']: - tc['function']['arguments'] += tool_call['function']['arguments'] - break - else: - tool_calls.append(tool_call) - - if chunk.choices[0].finish_reason is not None: - break - message_delta['content'] = pending_content - message_delta['role'] = 'assistant' - - message_delta['tool_calls'] = tool_calls if tool_calls else None - return [message_delta] - - async def _make_msg( - self, - chat_completion: list[dict[str, typing.Any]], - ) -> provider_message.Message: - chatcmpl_message = chat_completion[0] - - # 确保 role 字段存在且不为 None - if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: - chatcmpl_message['role'] = 'assistant' - - message = provider_message.Message(**chatcmpl_message) - - return message - - async def _closure( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> tuple[provider_message.Message, dict]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - - # 发送请求 - resp = await self._req(query, args, extra_body=extra_args, remove_think=remove_think) - - # 处理请求结果 - message = await self._make_msg(resp) - - # ModelScope uses streaming, usage info not available - usage_info = {} - - return message, usage_info - - async def _req_stream( - self, - args: dict, - extra_body: dict = {}, - ) -> chat_completion.ChatCompletion: - async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): - yield chunk - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - args['stream'] = True - - # 流式处理状态 - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - # accumulated_reasoning = '' # 仅用于判断何时结束思维链 - - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - reasoning_content = delta.get('reasoning_content', '') - - # 处理 reasoning_content - if reasoning_content: - # accumulated_reasoning += reasoning_content - # 如果设置了 remove_think,跳过 reasoning_content - if remove_think: - chunk_idx += 1 - continue - - # 第一次出现 reasoning_content,添加 开始标签 - if not thinking_started: - thinking_started = True - delta_content = '\n' + reasoning_content - else: - # 继续输出 reasoning_content - delta_content = reasoning_content - elif thinking_started and not thinking_ended and delta_content: - # reasoning_content 结束,normal content 开始,添加 结束标签 - thinking_ended = True - delta_content = '\n\n' + delta_content - - # 处理 content 中已有的 标签(如果需要移除) - # if delta_content and remove_think and '' in delta_content: - # import re - # - # # 移除 标签及其内容 - # delta_content = re.sub(r'.*?', '', delta_content, flags=re.DOTALL) - - # 处理工具调用增量 - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] != '': - tool_id = tool_call['id'] - if tool_call['function']['name'] is not None: - tool_name = tool_call['function']['name'] - - if tool_call['type'] is None: - tool_call['type'] = 'function' - tool_call['id'] = tool_id - tool_call['function']['name'] = tool_name - tool_call['function']['arguments'] = ( - '' if tool_call['function']['arguments'] is None else tool_call['function']['arguments'] - ) - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 - # return - - async def invoke_llm( - self, - query: pipeline_query.Query, - model: entities.LLMModelInfo, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message: - req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 - for m in messages: - msg_dict = m.dict(exclude_none=True) - content = msg_dict.get('content') - if isinstance(content, list): - # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - # 将所有文本部分合并为一个字符串 - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) - - try: - return await self._closure( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - remove_think=remove_think, - ) - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - except openai.BadRequestError as e: - if 'context_length_exceeded' in e.message: - raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') - else: - raise errors.RequesterError(f'请求参数错误: {e.message}') - except openai.AuthenticationError as e: - raise errors.RequesterError(f'无效的 api-key: {e.message}') - except openai.NotFoundError as e: - raise errors.RequesterError(f'请求路径错误: {e.message}') - except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') - except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') - - async def invoke_llm_stream( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.MessageChunk: - req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 - for m in messages: - msg_dict = m.dict(exclude_none=True) - content = msg_dict.get('content') - if isinstance(content, list): - # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - # 将所有文本部分合并为一个字符串 - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) - - try: - async for item in self._closure_stream( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - remove_think=remove_think, - ): - yield item - - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - except openai.BadRequestError as e: - if 'context_length_exceeded' in e.message: - raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') - else: - raise errors.RequesterError(f'请求参数错误: {e.message}') - except openai.AuthenticationError as e: - raise errors.RequesterError(f'无效的 api-key: {e.message}') - except openai.NotFoundError as e: - raise errors.RequesterError(f'请求路径错误: {e.message}') - except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') - except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index 8d22002d..cc9859e6 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 魔搭社区 icon: modelscope.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -29,8 +30,11 @@ spec: type: int required: true default: 120 + alias: "modelscope ModelScope 魔搭 魔塔 摩搭 阿里 modelscope-aigc qwen bge" support_type: - llm + - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py deleted file mode 100644 index b6852963..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import typing - - -from . import chatcmpl -from .. import requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - - -class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): - """Moonshot ChatCompletion API 请求器""" - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.moonshot.cn/v1', - 'timeout': 120, - } - - async def _closure( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> tuple[provider_message.Message, dict]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages - - # deepseek 不支持多模态,把content都转换成纯文字 - for m in messages: - if 'content' in m and isinstance(m['content'], list): - m['content'] = ' '.join([c['text'] for c in m['content']]) - - # 删除空的,不知道干嘛的,直接删了。 - # messages = [m for m in messages if m["content"].strip() != "" and ('tool_calls' not in m or not m['tool_calls'])] - - args['messages'] = messages - - # 发送请求 - resp = await self._req(args, extra_body=extra_args) - - # 处理请求结果 - message = await self._make_msg(resp, remove_think) - - # Extract token usage from response - usage_info = {} - if hasattr(resp, 'usage') and resp.usage: - usage_info['input_tokens'] = resp.usage.prompt_tokens or 0 - usage_info['output_tokens'] = resp.usage.completion_tokens or 0 - usage_info['total_tokens'] = resp.usage.total_tokens or 0 - - return message, usage_info diff --git a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index 7a7e3060..50413f3b 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 月之暗面 icon: moonshot.png spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "moonshot Moonshot 月之暗面 月暗 kimi Kimi 月之 暗面 moonshot-v1 k2" support_type: - llm provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.py deleted file mode 100644 index 3c2bd3fb..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class NewAPIChatCompletions(chatcmpl.OpenAIChatCompletions): - """New API ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'http://localhost:3000/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.yaml index e0f44e99..694af440 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/newapichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: New API icon: newapi.png spec: + litellm_provider: openai config: - name: base_url label: @@ -22,9 +23,11 @@ spec: type: integer required: true default: 120 + alias: "newapi new-api New API one-api oneapi 中转 中转站 aggregator 聚合 网关 gateway rerank" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py deleted file mode 100644 index 50f601d7..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py +++ /dev/null @@ -1,314 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import typing -from typing import Union, Mapping, Any, AsyncIterator -import uuid -import json - -import ollama -import httpx - -from .. import errors, requester -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.provider.message as provider_message - -REQUESTER_NAME: str = 'ollama-chat' - - -class OllamaChatCompletions(requester.ProviderAPIRequester): - """Ollama平台 ChatCompletion API请求器""" - - client: ollama.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'http://127.0.0.1:11434', - 'timeout': 120, - } - - async def initialize(self): - os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url'] - self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout']) - - def _infer_model_type(self, model_id: str) -> str: - normalized_model_id = (model_id or '').lower() - embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding') - return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' - - def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: - normalized_model_id = (model_id or '').lower() - abilities: set[str] = set() - details = item.get('details', {}) or {} - families = details.get('families', []) or [] - tokens = [normalized_model_id, str(details.get('family', '')).lower()] - tokens.extend(str(family).lower() for family in families) - - if any(keyword in token for token in tokens for keyword in ('vision', 'vl', 'omni', 'llava', 'ocr')): - abilities.add('vision') - if any(keyword in token for token in tokens for keyword in ('tool', 'function')): - abilities.add('func_call') - return sorted(abilities) - - async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: - del api_key - models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/api/tags' - - async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: - response = await client.get(models_url) - response.raise_for_status() - payload = response.json() - - models: list[dict[str, typing.Any]] = [] - for item in payload.get('models', []): - model_id = item.get('model') or item.get('name') - if not model_id: - continue - models.append( - { - 'id': model_id, - 'name': item.get('name', model_id), - 'type': self._infer_model_type(model_id), - 'abilities': self._infer_model_abilities(item, model_id), - } - ) - - models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) - return { - 'models': models, - 'debug': { - 'request': { - 'method': 'GET', - 'url': models_url, - }, - 'response': payload, - }, - } - - async def _req( - self, - args: dict, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - return await self.client.chat(**args) - - async def _closure( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message: - args = extra_args.copy() - args['model'] = use_model.model_entity.name - - messages: list[dict] = req_messages.copy() - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - text_content: list = [] - image_urls: list = [] - for me in msg['content']: - if me['type'] == 'text': - text_content.append(me['text']) - elif me['type'] == 'image_base64': - image_urls.append(me['image_base64']) - - msg['content'] = '\n'.join(text_content) - msg['images'] = [url.split(',')[1] for url in image_urls] - if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict - for tool_call in msg['tool_calls']: - tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) - args['messages'] = messages - - args['tools'] = [] - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - if tools: - args['tools'] = tools - - resp = await self._req(args) - message: provider_message.Message = await self._make_msg(resp) - return message - - async def _make_msg(self, chat_completions: ollama.ChatResponse) -> provider_message.Message: - message: ollama.Message = chat_completions.message - if message is None: - raise ValueError("chat_completions must contain a 'message' field") - - ret_msg: provider_message.Message = None - - if message.content is not None: - ret_msg = provider_message.Message(role='assistant', content=message.content) - if message.tool_calls is not None and len(message.tool_calls) > 0: - tool_calls: list[provider_message.ToolCall] = [] - - for tool_call in message.tool_calls: - tool_calls.append( - provider_message.ToolCall( - id=uuid.uuid4().hex, - type='function', - function=provider_message.FunctionCall( - name=tool_call.function.name, - arguments=json.dumps(tool_call.function.arguments), - ), - ) - ) - ret_msg.tool_calls = tool_calls - - return ret_msg - - async def _prepare_messages( - self, - messages: typing.List[provider_message.Message], - ) -> list[dict]: - """Prepare messages for Ollama API request.""" - req_messages: list = [] - for m in messages: - msg_dict: dict = m.dict(exclude_none=True) - content: Any = msg_dict.get('content') - if isinstance(content, list): - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) - return req_messages - - async def invoke_llm( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message: - req_messages = await self._prepare_messages(messages) - try: - return await self._closure( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - remove_think=remove_think, - ) - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - - async def invoke_llm_stream( - self, - query: pipeline_query.Query, - model: requester.RuntimeLLMModel, - messages: typing.List[provider_message.Message], - funcs: typing.List[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.MessageChunk: - req_messages = await self._prepare_messages(messages) - - try: - args = extra_args.copy() - args['model'] = model.model_entity.name - - # Process messages for Ollama format - msgs: list[dict] = req_messages.copy() - for msg in msgs: - if 'content' in msg and isinstance(msg['content'], list): - text_content: list = [] - image_urls: list = [] - for me in msg['content']: - if me['type'] == 'text': - text_content.append(me['text']) - elif me['type'] == 'image_base64': - image_urls.append(me['image_base64']) - msg['content'] = '\n'.join(text_content) - msg['images'] = [url.split(',')[1] for url in image_urls] - if 'tool_calls' in msg: - for tool_call in msg['tool_calls']: - tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) - args['messages'] = msgs - - args['tools'] = [] - if funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) - if tools: - args['tools'] = tools - - args['stream'] = True - - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' - - async for chunk in await self.client.chat(**args): - message: ollama.Message = chunk.message - done = chunk.done - - delta_content = message.content or '' - reasoning_content = getattr(message, 'thinking', '') or '' - - # Handle reasoning/thinking content - if reasoning_content: - if remove_think: - chunk_idx += 1 - continue - - if not thinking_started: - thinking_started = True - delta_content = '\n' + reasoning_content - else: - delta_content = reasoning_content - elif thinking_started and not thinking_ended and delta_content: - thinking_ended = True - delta_content = '\n\n' + delta_content - - # Handle tool calls - tool_calls_data = None - if message.tool_calls: - tool_calls_data = [] - for tc in message.tool_calls: - tool_calls_data.append( - { - 'id': uuid.uuid4().hex, - 'type': 'function', - 'function': { - 'name': tc.function.name, - 'arguments': json.dumps(tc.function.arguments), - }, - } - ) - - # Skip empty first chunk - if chunk_idx == 0 and not delta_content and not reasoning_content and not tool_calls_data: - chunk_idx += 1 - continue - - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': tool_calls_data, - 'is_final': bool(done), - } - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 - - except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') - - async def invoke_embedding( - self, - model: requester.RuntimeEmbeddingModel, - input_text: list[str], - extra_args: dict[str, typing.Any] = {}, - ) -> list[list[float]]: - return ( - await self.client.embed( - model=model.model_entity.name, - input=input_text, - **extra_args, - ) - ).embeddings diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.yaml b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.yaml index a724f8f8..83e116c8 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Ollama icon: ollama.svg spec: + litellm_provider: ollama config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "ollama Ollama 本地 local 本地部署 self-hosted llama gguf 私有化" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py deleted file mode 100644 index 17b88431..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import modelscopechatcmpl - - -class OpenRouterChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions): - """OpenRouter ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://openrouter.ai/api/v1', - 'timeout': 120, - } - - async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: - original_base_url = self.requester_cfg.get('base_url', '') - self.requester_cfg['base_url'] = 'https://openrouter.ai/api/v1' - try: - return await super().scan_models(api_key) - finally: - self.requester_cfg['base_url'] = original_base_url diff --git a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 71064dc0..9fe351ce 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: OpenRouter icon: openrouter.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "openrouter OpenRouter open-router 中转 中转站 路由 aggregator gpt claude gemini llama" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py deleted file mode 100644 index 1836bd62..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ /dev/null @@ -1,208 +0,0 @@ -from __future__ import annotations - -import openai -import typing - -from . import chatcmpl -from .. import requester -import openai.types.chat.chat_completion as chat_completion -import re -import langbot_plugin.api.entities.builtin.provider.message as provider_message -import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -import langbot_plugin.api.entities.builtin.resource.tool as resource_tool - - -class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): - """欧派云 ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.ppinfra.com/v3/openai', - 'timeout': 120, - } - - is_think: bool = False - - async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - remove_think: bool, - ) -> provider_message.Message: - chatcmpl_message = chat_completion.choices[0].message.model_dump() - # print(chatcmpl_message.keys(), chatcmpl_message.values()) - - # 确保 role 字段存在且不为 None - if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: - chatcmpl_message['role'] = 'assistant' - - reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None - - # deepseek的reasoner模型 - chatcmpl_message['content'] = await self._process_thinking_content( - chatcmpl_message['content'], reasoning_content, remove_think - ) - - # 移除 reasoning_content 字段,避免传递给 Message - if 'reasoning_content' in chatcmpl_message: - del chatcmpl_message['reasoning_content'] - - message = provider_message.Message(**chatcmpl_message) - - return message - - async def _process_thinking_content( - self, - content: str, - reasoning_content: str = None, - remove_think: bool = False, - ) -> tuple[str, str]: - """处理思维链内容 - - Args: - content: 原始内容 - reasoning_content: reasoning_content 字段内容 - remove_think: 是否移除思维链 - - Returns: - 处理后的内容 - """ - if remove_think: - content = re.sub(r'.*?', '', content, flags=re.DOTALL) - else: - if reasoning_content is not None: - content = '\n' + reasoning_content + '\n\n' + content - return content - - async def _make_msg_chunk( - self, - delta: dict[str, typing.Any], - idx: int, - ) -> provider_message.MessageChunk: - # 处理流式chunk和完整响应的差异 - # print(chat_completion.choices[0]) - - # 确保 role 字段存在且不为 None - if 'role' not in delta or delta['role'] is None: - delta['role'] = 'assistant' - - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None - - delta['content'] = '' if delta['content'] is None else delta['content'] - # print(reasoning_content) - - # deepseek的reasoner模型 - - if reasoning_content is not None: - delta['content'] += reasoning_content - - message = provider_message.MessageChunk(**delta) - - return message - - async def _closure_stream( - self, - query: pipeline_query.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[resource_tool.LLMTool] = None, - extra_args: dict[str, typing.Any] = {}, - remove_think: bool = False, - ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.provider.token_mgr.get_token() - - args = {} - args['model'] = use_model.model_entity.name - - if use_funcs: - tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) - - if tools: - args['tools'] = tools - - # 设置此次请求中的messages - messages = req_messages.copy() - - # 检查vision - for msg in messages: - if 'content' in msg and isinstance(msg['content'], list): - for me in msg['content']: - if me['type'] == 'image_base64': - me['image_url'] = {'url': me['image_base64']} - me['type'] = 'image_url' - del me['image_base64'] - - args['messages'] = messages - args['stream'] = True - - # tool_calls_map: dict[str, provider_message.ToolCall] = {} - chunk_idx = 0 - thinking_started = False - thinking_ended = False - role = 'assistant' # 默认角色 - async for chunk in self._req_stream(args, extra_body=extra_args): - # 解析 chunk 数据 - if hasattr(chunk, 'choices') and chunk.choices: - choice = chunk.choices[0] - delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {} - finish_reason = getattr(choice, 'finish_reason', None) - else: - delta = {} - finish_reason = None - - # 从第一个 chunk 获取 role,后续使用这个 role - if 'role' in delta and delta['role']: - role = delta['role'] - - # 获取增量内容 - delta_content = delta.get('content', '') - # reasoning_content = delta.get('reasoning_content', '') - - if remove_think: - if delta['content'] is not None: - if '' in delta['content'] and not thinking_started and not thinking_ended: - thinking_started = True - continue - elif delta['content'] == r'' and not thinking_ended: - thinking_ended = True - continue - elif thinking_ended and delta['content'] == '\n\n' and thinking_started: - thinking_started = False - continue - elif thinking_started and not thinking_ended: - continue - - # delta_tool_calls = None - if delta.get('tool_calls'): - for tool_call in delta['tool_calls']: - if tool_call['id'] and tool_call['function']['name']: - tool_id = tool_call['id'] - tool_name = tool_call['function']['name'] - - if tool_call['id'] is None: - tool_call['id'] = tool_id - if tool_call['function']['name'] is None: - tool_call['function']['name'] = tool_name - if tool_call['function']['arguments'] is None: - tool_call['function']['arguments'] = '' - if tool_call['type'] is None: - tool_call['type'] = 'function' - - # 跳过空的第一个 chunk(只有 role 没有内容) - if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'): - chunk_idx += 1 - continue - - # 构建 MessageChunk - 只包含增量内容 - chunk_data = { - 'role': role, - 'content': delta_content if delta_content else None, - 'tool_calls': delta.get('tool_calls'), - 'is_final': bool(finish_reason), - } - - # 移除 None 值 - chunk_data = {k: v for k, v in chunk_data.items() if v is not None} - - yield provider_message.MessageChunk(**chunk_data) - chunk_idx += 1 diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9e8eb1b0..79408fb5 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 派欧云 icon: ppio.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -29,9 +30,11 @@ spec: type: int required: true default: 120 + alias: "ppio PPIO 派欧 派欧云 paiou ppinfra 派欧算力 bge embedding rerank" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.py deleted file mode 100644 index a68b6896..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import openai -import typing - -from . import chatcmpl - - -class QHAIGCChatCompletions(chatcmpl.OpenAIChatCompletions): - """启航 AI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.qhaigc.com/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.yaml index 46ae1fad..28680c0f 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/qhaigcchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 启航 AI icon: qhaigc.png spec: + litellm_provider: openai config: - name: base_url label: @@ -29,9 +30,11 @@ spec: type: int required: true default: 120 + alias: "qhaigc 青华 qinghua aigc 中转 中转站" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.py index 0c7a940f..84c59b74 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.py @@ -2,19 +2,16 @@ from __future__ import annotations import typing -import openai - -from . import chatcmpl +from . import litellmchat -class QiniuChatCompletions(chatcmpl.OpenAIChatCompletions): +class QiniuChatCompletions(litellmchat.LiteLLMRequester): """七牛云 ChatCompletion API 请求器""" - client: openai.AsyncClient - default_config: dict[str, typing.Any] = { 'base_url': 'https://api.qnaigc.com/v1', 'timeout': 120, + 'custom_llm_provider': 'openai', } async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.yaml index 5655d743..96a048fa 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/qiniuchatcmpl.yaml @@ -22,8 +22,11 @@ spec: type: integer required: true default: 120 + alias: "qiniu 七牛 七牛云 qiniu-cloud kodo ai推理 bge embedding rerank" support_type: - llm + - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.yaml b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.yaml index 1ff48b50..d9aedad3 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.yaml @@ -12,6 +12,7 @@ metadata: icon: seekdb.svg spec: config: [] + alias: "seekdb SeekDB seek-db 向量 vector embedding 嵌入 数据库" support_type: - text-embedding provider_category: builtin diff --git a/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.py b/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.py deleted file mode 100644 index 122eaf7d..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import openai -import typing - -from . import chatcmpl -import openai.types.chat.chat_completion as chat_completion - - -class ShengSuanYunChatCompletions(chatcmpl.OpenAIChatCompletions): - """胜算云(ModelSpot.AI) ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://router.shengsuanyun.com/api/v1', - 'timeout': 120, - } - - async def _req( - self, - args: dict, - extra_body: dict = {}, - ) -> chat_completion.ChatCompletion: - return await self.client.chat.completions.create( - **args, - extra_body=extra_body, - extra_headers={ - 'HTTP-Referer': 'https://langbot.app', - 'X-Title': 'LangBot', - }, - ) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.yaml b/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.yaml index 77cf682c..ae54668d 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/shengsuanyun.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 胜算云 icon: shengsuanyun.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -29,9 +30,11 @@ spec: type: int required: true default: 120 + alias: "shengsuanyun 胜算云 胜算 sheng suan yun 算力 中转" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py deleted file mode 100644 index 3636d9d1..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions): - """SiliconFlow ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.siliconflow.cn/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 11a2ffa3..b4e5d736 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 硅基流动 icon: siliconflow.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "siliconflow SiliconFlow 硅基流动 硅基 silicon flow guiji bge BAAI embedding rerank qwen deepseek" support_type: - llm - text-embedding diff --git a/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.py deleted file mode 100644 index 91740a1f..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class LangBotSpaceChatCompletions(chatcmpl.OpenAIChatCompletions): - """LangBot Space ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.langbot.cloud/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.yaml index 29c23a83..4bfdb98a 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/spacechatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Space icon: space.webp spec: + litellm_provider: openai config: - name: base_url label: @@ -22,9 +23,11 @@ spec: type: integer required: true default: 120 + alias: "space LangBot Space langbot-space 官方 official 自有 内置 rerank embedding" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/tencent.svg b/src/langbot/pkg/provider/modelmgr/requesters/tencent.svg new file mode 100644 index 00000000..de32c1bf --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/tencent.svg @@ -0,0 +1,5 @@ + + + Tencent + Hunyuan + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/tencentchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/tencentchatcmpl.yaml new file mode 100644 index 00000000..4e2d68bd --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/tencentchatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: tencent-chat-completions + label: + en_US: Tencent Hunyuan + zh_Hans: 腾讯混元 + icon: tencent.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://hunyuan.tencentcloudapi.com/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "tencent 腾讯 腾讯云 hunyuan 混元 tencent-cloud txcloud 元宝" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/together.svg b/src/langbot/pkg/provider/modelmgr/requesters/together.svg new file mode 100644 index 00000000..b6ce0f80 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/together.svg @@ -0,0 +1,5 @@ + + + Together + AI + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/togetherchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/togetherchatcmpl.yaml new file mode 100644 index 00000000..c2869a24 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/togetherchatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: together-chat-completions + label: + en_US: Together AI + zh_Hans: Together AI + icon: together.svg +spec: + litellm_provider: together_ai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.together.xyz/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "together Together together-ai togetherai 中转 llama qwen bge rerank embedding" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/tokenpony.yaml b/src/langbot/pkg/provider/modelmgr/requesters/tokenpony.yaml index f160bdea..c8fba393 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/tokenpony.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/tokenpony.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 小马算力 icon: tokenpony.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,9 +23,11 @@ spec: type: integer required: true default: 120 + alias: "tokenpony TokenPony token-pony 小马 token 小马算力 中转" support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/tokenponychatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/tokenponychatcmpl.py deleted file mode 100644 index 92311454..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/tokenponychatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class TokenPonyChatCompletions(chatcmpl.OpenAIChatCompletions): - """TokenPony ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.tokenpony.cn/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py deleted file mode 100644 index 7eb68956..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions): - """火山方舟大模型平台 ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://ark.cn-beijing.volces.com/api/v3', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index e5c82657..fd709f0e 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 火山方舟 icon: volcark.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,8 +23,11 @@ spec: type: integer required: true default: 120 + alias: "volcark volcengine 火山 火山方舟 火山引擎 ark 方舟 字节 bytedance doubao 豆包 seed embedding rerank" support_type: - llm + - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml index a47b8d47..c10b7e03 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: Voyage AI icon: voyageai.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "voyage voyageai voyage-ai VoyageAI rerank 重排 reranker voyage-rerank embedding" support_type: - rerank provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.py deleted file mode 100644 index db2022f1..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class XaiChatCompletions(chatcmpl.OpenAIChatCompletions): - """xAI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://api.x.ai/v1', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 2e721d70..0d55f7ba 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: xAI icon: xai.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,6 +23,7 @@ spec: type: integer required: true default: 120 + alias: "xai xAI x-ai grok Grok 马斯克 musk x.ai 格罗克" support_type: - llm provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/yi.svg b/src/langbot/pkg/provider/modelmgr/requesters/yi.svg new file mode 100644 index 00000000..8dc5e827 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/yi.svg @@ -0,0 +1,5 @@ + + + 01.AI + Yi + diff --git a/src/langbot/pkg/provider/modelmgr/requesters/yichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/yichatcmpl.yaml new file mode 100644 index 00000000..e75d0cdf --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/yichatcmpl.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: yi-chat-completions + label: + en_US: 01.AI Yi + zh_Hans: 零一万物 + icon: yi.svg +spec: + litellm_provider: openai + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.lingyiwanwu.com/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + alias: "yi 零一 零一万物 零一万 lingyiwanwu 01 01.ai 万智 yi-large yi-lightning embedding" + support_type: + - llm + - text-embedding + - rerank + provider_category: manufacturer diff --git a/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py deleted file mode 100644 index a1a07068..00000000 --- a/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing -import openai - -from . import chatcmpl - - -class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions): - """智谱AI ChatCompletion API 请求器""" - - client: openai.AsyncClient - - default_config: dict[str, typing.Any] = { - 'base_url': 'https://open.bigmodel.cn/api/paas/v4', - 'timeout': 120, - } diff --git a/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a4ebb2ec..a0dabf8d 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -7,6 +7,7 @@ metadata: zh_Hans: 智谱 AI icon: zhipuai.svg spec: + litellm_provider: openai config: - name: base_url label: @@ -22,8 +23,11 @@ spec: type: integer required: true default: 120 + alias: "zhipu zhipuai 智谱 智谱AI 智谱清言 glm GLM chatglm 清言 bigmodel embedding-3 rerank" support_type: - llm + - text-embedding + - rerank provider_category: manufacturer execution: python: diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 710bce96..9a90ed47 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -42,6 +42,64 @@ SANDBOX_EXEC_SYSTEM_GUIDANCE = ( MAX_TOOL_CALL_ROUNDS = 128 +def _model_has_ability(model: modelmgr_requester.RuntimeLLMModel, ability: str) -> bool: + return ability in (model.model_entity.abilities or []) + + +class _StreamAccumulator: + """Accumulate streamed content and fragmented OpenAI-style tool calls.""" + + def __init__(self, msg_sequence: int = 0, initial_content: str | None = None): + self.tool_calls_map: dict[str, provider_message.ToolCall] = {} + self.msg_idx = 0 + self.accumulated_content = initial_content or '' + self.last_role = 'assistant' + self.msg_sequence = msg_sequence + + def add(self, msg: provider_message.MessageChunk) -> provider_message.MessageChunk | None: + self.msg_idx += 1 + + if msg.role: + self.last_role = msg.role + + if msg.content: + self.accumulated_content += msg.content + + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in self.tool_calls_map: + self.tool_calls_map[tool_call.id] = provider_message.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=provider_message.FunctionCall( + name=tool_call.function.name if tool_call.function else '', + arguments='', + ), + ) + if tool_call.function and tool_call.function.arguments: + self.tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + + if self.msg_idx % 8 == 0 or msg.is_final: + self.msg_sequence += 1 + return provider_message.MessageChunk( + role=self.last_role, + content=self.accumulated_content, + tool_calls=list(self.tool_calls_map.values()) if (self.tool_calls_map and msg.is_final) else None, + is_final=msg.is_final, + msg_sequence=self.msg_sequence, + ) + + return None + + def final_message(self) -> provider_message.MessageChunk: + return provider_message.MessageChunk( + role=self.last_role, + content=self.accumulated_content, + tool_calls=list(self.tool_calls_map.values()) if self.tool_calls_map else None, + msg_sequence=self.msg_sequence, + ) + + @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """Local agent request runner""" @@ -106,7 +164,7 @@ class LocalAgentRunner(runner.RequestRunner): query, model, messages, - funcs if model.model_entity.abilities.__contains__('func_call') else [], + funcs if _model_has_ability(model, 'func_call') else [], extra_args=model.model_entity.extra_args, remove_think=remove_think, ) @@ -136,7 +194,7 @@ class LocalAgentRunner(runner.RequestRunner): query, model, messages, - funcs if model.model_entity.abilities.__contains__('func_call') else [], + funcs if _model_has_ability(model, 'func_call') else [], extra_args=model.model_entity.extra_args, remove_think=remove_think, ) @@ -322,11 +380,7 @@ class LocalAgentRunner(runner.RequestRunner): final_msg = msg else: # Streaming: invoke with fallback - tool_calls_map: dict[str, provider_message.ToolCall] = {} - msg_idx = 0 - accumulated_content = '' - last_role = 'assistant' - msg_sequence = 1 + stream_accumulator = _StreamAccumulator(msg_sequence=1) stream_src, use_llm_model = await self._invoke_stream_with_fallback( query, @@ -336,44 +390,12 @@ class LocalAgentRunner(runner.RequestRunner): remove_think, ) async for msg in stream_src: - msg_idx = msg_idx + 1 - - if msg.role: - last_role = msg.role - - if msg.content: - accumulated_content += msg.content - - if msg.tool_calls: - for tool_call in msg.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = provider_message.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=provider_message.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - if msg_idx % 8 == 0 or msg.is_final: - msg_sequence += 1 - yield provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, - is_final=msg.is_final, - msg_sequence=msg_sequence, - ) + chunk = stream_accumulator.add(msg) + if chunk: + yield chunk initial_response_emitted = True - final_msg = provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, - msg_sequence=msg_sequence, - ) + final_msg = stream_accumulator.final_message() pending_tool_calls = final_msg.tool_calls first_content = final_msg.content @@ -459,69 +481,32 @@ class LocalAgentRunner(runner.RequestRunner): ) if is_stream: - tool_calls_map = {} - msg_idx = 0 - accumulated_content = '' - last_role = 'assistant' - msg_sequence = first_end_sequence + stream_accumulator = _StreamAccumulator( + msg_sequence=first_end_sequence, + initial_content=first_content, + ) tool_stream_src = use_llm_model.provider.invoke_llm_stream( query, use_llm_model, req_messages, - query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [], + query.use_funcs if _model_has_ability(use_llm_model, 'func_call') else [], extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) async for msg in tool_stream_src: - msg_idx += 1 + chunk = stream_accumulator.add(msg) + if chunk: + yield chunk - if msg.role: - last_role = msg.role - - # Prepend first-round content on first chunk of tool-call round - if msg_idx == 1: - accumulated_content = first_content if first_content is not None else accumulated_content - - if msg.content: - accumulated_content += msg.content - - if msg.tool_calls: - for tool_call in msg.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = provider_message.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=provider_message.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - if msg_idx % 8 == 0 or msg.is_final: - msg_sequence += 1 - yield provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None, - is_final=msg.is_final, - msg_sequence=msg_sequence, - ) - - final_msg = provider_message.MessageChunk( - role=last_role, - content=accumulated_content, - tool_calls=list(tool_calls_map.values()) if tool_calls_map else None, - msg_sequence=msg_sequence, - ) + final_msg = stream_accumulator.final_message() else: # Non-streaming: use committed model directly (no fallback in tool loop) msg = await use_llm_model.provider.invoke_llm( query, use_llm_model, req_messages, - query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [], + query.use_funcs if _model_has_ability(use_llm_model, 'func_call') else [], extra_args=use_llm_model.model_entity.extra_args, remove_think=remove_think, ) diff --git a/src/langbot/pkg/provider/tools/toolmgr.py b/src/langbot/pkg/provider/tools/toolmgr.py index 16177a0b..fd03b303 100644 --- a/src/langbot/pkg/provider/tools/toolmgr.py +++ b/src/langbot/pkg/provider/tools/toolmgr.py @@ -83,19 +83,6 @@ class ToolManager: return tools - async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list: - tools = [] - - for function in use_funcs: - function_schema = { - 'name': function.name, - 'description': function.description, - 'input_schema': function.parameters, - } - tools.append(function_schema) - - return tools - async def execute_func_call(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any: from langbot.pkg.telemetry import features as telemetry_features diff --git a/src/langbot/pkg/utils/constants.py b/src/langbot/pkg/utils/constants.py index 4fad9069..f97255ab 100644 --- a/src/langbot/pkg/utils/constants.py +++ b/src/langbot/pkg/utils/constants.py @@ -2,7 +2,7 @@ import langbot semantic_version = f'v{langbot.__version__}' -required_database_version = 25 +required_database_version = 26 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/tests/integration/persistence/test_migrations.py b/tests/integration/persistence/test_migrations.py index 03392c93..be3427a5 100644 --- a/tests/integration/persistence/test_migrations.py +++ b/tests/integration/persistence/test_migrations.py @@ -104,7 +104,7 @@ class TestSQLiteMigrationUpgrade: rev = await get_alembic_current(sqlite_engine) assert rev is not None, "Expected a revision after upgrade" # Head should be the latest migration - assert rev.startswith('0004'), f"Expected head to be 0004_*, got {rev}" + assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}" @pytest.mark.asyncio async def test_upgrade_idempotent(self, sqlite_engine): diff --git a/tests/integration/persistence/test_migrations_postgres.py b/tests/integration/persistence/test_migrations_postgres.py index 7867d4af..20f89215 100644 --- a/tests/integration/persistence/test_migrations_postgres.py +++ b/tests/integration/persistence/test_migrations_postgres.py @@ -150,8 +150,8 @@ class TestPostgreSQLMigrationUpgrade: # Verify revision rev = await get_alembic_current(postgres_engine) assert rev is not None, "Expected a revision after upgrade" - # Head should be the latest migration (0004 for current state) - assert rev.startswith('0004'), f"Expected head to be 0004_*, got {rev}" + # Head should be the latest migration (0005 for current state) + assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}" @pytest.mark.asyncio async def test_postgres_upgrade_idempotent( diff --git a/tests/unit_tests/api/service/test_model_service.py b/tests/unit_tests/api/service/test_model_service.py index 6e6d2598..a0ffc92d 100644 --- a/tests/unit_tests/api/service/test_model_service.py +++ b/tests/unit_tests/api/service/test_model_service.py @@ -23,6 +23,7 @@ from langbot.pkg.api.http.service.model import ( RerankModelsService, _parse_provider_api_keys, _runtime_model_data, + _validate_provider_supports, ) from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, RerankModel, ModelProvider @@ -35,6 +36,7 @@ def _create_mock_llm_model( name: str = 'Test LLM', provider_uuid: str = 'provider-uuid', abilities: list = None, + context_length: int | None = None, extra_args: dict = None, ) -> Mock: """Helper to create mock LLMModel entity.""" @@ -43,6 +45,7 @@ def _create_mock_llm_model( model.name = name model.provider_uuid = provider_uuid model.abilities = abilities or [] + model.context_length = context_length model.extra_args = extra_args or {} return model @@ -142,10 +145,12 @@ class TestRuntimeModelData: 'name': 'Model', 'provider_uuid': 'provider', 'abilities': ['vision'], + 'context_length': 128000, 'extra_args': {'temp': 0.7}, } result = _runtime_model_data('uuid', update_payload) assert result['abilities'] == ['vision'] + assert result['context_length'] == 128000 assert result['extra_args'] == {'temp': 0.7} @@ -188,7 +193,7 @@ class TestLLMModelsServiceGetLLMModels: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() - model = _create_mock_llm_model() + model = _create_mock_llm_model(context_length=128000) provider = _create_mock_provider() mock_model_result = _create_mock_result([model]) @@ -206,6 +211,7 @@ class TestLLMModelsServiceGetLLMModels: 'uuid': entity.uuid, 'name': entity.name, 'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None, + 'context_length': getattr(entity, 'context_length', None), 'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None, } ) @@ -218,6 +224,7 @@ class TestLLMModelsServiceGetLLMModels: # Verify assert len(result) == 1 assert result[0]['name'] == 'Test LLM' + assert result[0]['context_length'] == 128000 async def test_get_llm_models_hide_secret_keys(self): """Hides secret API keys when include_secret=False.""" @@ -265,7 +272,7 @@ class TestLLMModelsServiceGetLLMModel: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() - model = _create_mock_llm_model(model_uuid='found-uuid') + model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000) provider = _create_mock_provider() mock_model_result = _create_mock_result([], first_item=model) @@ -279,11 +286,12 @@ class TestLLMModelsServiceGetLLMModel: ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) ap.persistence_mgr.serialize_model = Mock( - return_value={ - 'uuid': 'found-uuid', - 'name': 'Test LLM', - 'provider_uuid': 'provider-uuid', - 'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']}, + side_effect=lambda model_cls, entity: { + 'uuid': entity.uuid, + 'name': entity.name, + 'provider_uuid': getattr(entity, 'provider_uuid', None), + 'context_length': getattr(entity, 'context_length', None), + 'api_keys': getattr(entity, 'api_keys', None), } ) @@ -295,6 +303,7 @@ class TestLLMModelsServiceGetLLMModel: # Verify assert result is not None assert result['uuid'] == 'found-uuid' + assert result['context_length'] == 128000 async def test_get_llm_model_not_found(self): """Returns None when model not found.""" @@ -402,6 +411,39 @@ class TestLLMModelsServiceCreateLLMModel: # Verify assert model_uuid == 'preserved-uuid' + async def test_create_llm_model_persists_context_length_as_column(self): + """Creates LLM model with context_length outside extra_args.""" + ap = SimpleNamespace() + ap.persistence_mgr = SimpleNamespace() + ap.model_mgr = SimpleNamespace() + ap.model_mgr.provider_dict = {'provider-uuid': Mock()} + ap.model_mgr.llm_models = [] + ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock()) + ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock()) + + mock_result = _create_mock_result([]) + ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) + + service = LLMModelsService(ap) + + await service.create_llm_model( + { + 'uuid': 'model-with-context', + 'name': 'Context Model', + 'provider_uuid': 'provider-uuid', + 'abilities': ['func_call'], + 'context_length': 128000, + 'extra_args': {'temperature': 0.2}, + }, + preserve_uuid=True, + auto_set_to_default_pipeline=False, + ) + + runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0] + assert runtime_entity.context_length == 128000 + assert runtime_entity.extra_args == {'temperature': 0.2} + assert 'context_length' not in runtime_entity.extra_args + async def test_create_llm_model_provider_not_found_raises_error(self): """Raises Exception when provider not found in runtime.""" # Setup @@ -512,6 +554,35 @@ class TestLLMModelsServiceUpdateLLMModel: 'provider_uuid': 'nonexistent-provider', }) + async def test_update_llm_model_reloads_context_length_as_column(self): + """Updates runtime model with context_length outside extra_args.""" + ap = SimpleNamespace() + ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) + ap.model_mgr = SimpleNamespace() + ap.model_mgr.provider_dict = {'provider-uuid': Mock()} + ap.model_mgr.llm_models = [] + ap.model_mgr.remove_llm_model = AsyncMock() + ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock()) + + service = LLMModelsService(ap) + + await service.update_llm_model( + 'existing-uuid', + { + 'name': 'Updated Name', + 'provider_uuid': 'provider-uuid', + 'abilities': ['vision'], + 'context_length': 64000, + 'extra_args': {'temperature': 0.4}, + }, + ) + + runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0] + assert runtime_entity.uuid == 'existing-uuid' + assert runtime_entity.context_length == 64000 + assert runtime_entity.extra_args == {'temperature': 0.4} + assert 'context_length' not in runtime_entity.extra_args + class TestLLMModelsServiceDeleteLLMModel: """Tests for LLMModelsService.delete_llm_model method.""" @@ -961,4 +1032,56 @@ class TestRerankModelsServiceGetRerankModelsByProvider: result = await service.get_rerank_models_by_provider('provider-uuid') # Verify - assert len(result) == 2 \ No newline at end of file + assert len(result) == 2 + + +class TestValidateProviderSupports: + """Tests for _validate_provider_supports guard.""" + + @staticmethod + def _make_ap(requester_name: str, support_type): + """Build a fake ap whose model_mgr resolves a manifest with support_type.""" + manifest = SimpleNamespace(spec={'support_type': support_type}) + runtime_provider = SimpleNamespace( + provider_entity=SimpleNamespace(requester=requester_name) + ) + model_mgr = SimpleNamespace( + provider_dict={'p1': runtime_provider}, + get_available_requester_manifest_by_name=lambda name: manifest + if name == requester_name + else None, + ) + return SimpleNamespace(model_mgr=model_mgr) + + async def test_allows_supported_type(self): + ap = self._make_ap('cohere-rerank', ['rerank']) + # Should not raise + await _validate_provider_supports(ap, 'p1', 'rerank') + + async def test_rejects_unsupported_type(self): + ap = self._make_ap('cohere-rerank', ['rerank']) + with pytest.raises(ValueError, match='does not support llm'): + await _validate_provider_supports(ap, 'p1', 'llm') + + async def test_allows_when_support_type_missing(self): + # Manifest without support_type must not block (backward compatible) + manifest = SimpleNamespace(spec={}) + runtime_provider = SimpleNamespace( + provider_entity=SimpleNamespace(requester='legacy') + ) + model_mgr = SimpleNamespace( + provider_dict={'p1': runtime_provider}, + get_available_requester_manifest_by_name=lambda name: manifest, + ) + ap = SimpleNamespace(model_mgr=model_mgr) + await _validate_provider_supports(ap, 'p1', 'rerank') + + async def test_allows_when_provider_unknown(self): + ap = self._make_ap('cohere-rerank', ['rerank']) + # Unknown provider uuid -> no entry -> no block + await _validate_provider_supports(ap, 'missing', 'llm') + + async def test_degrades_when_model_mgr_incomplete(self): + # A bare ap without a usable model_mgr must not raise (defensive) + ap = SimpleNamespace(model_mgr=SimpleNamespace()) + await _validate_provider_supports(ap, 'p1', 'llm') diff --git a/tests/unit_tests/provider/__init__.py b/tests/unit_tests/provider/__init__.py index 8b137891..758036b7 100644 --- a/tests/unit_tests/provider/__init__.py +++ b/tests/unit_tests/provider/__init__.py @@ -1 +1 @@ - +"""Provider requester tests""" diff --git a/tests/unit_tests/provider/requesters/test_anthropic_requester.py b/tests/unit_tests/provider/requesters/test_anthropic_requester.py deleted file mode 100644 index 54abb615..00000000 --- a/tests/unit_tests/provider/requesters/test_anthropic_requester.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Tests for AnthropicMessages requester. - -Tests config and pure utility methods. -""" - -from __future__ import annotations - -from unittest.mock import MagicMock - - -class TestAnthropicMessagesConfig: - """Tests for default config.""" - - def test_default_config_values(self): - """Check default_config.""" - from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages - - assert AnthropicMessages.default_config['base_url'] == 'https://api.anthropic.com' - assert AnthropicMessages.default_config['timeout'] == 120 - - def test_config_override(self): - """Config can override defaults.""" - from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages - - mock_app = MagicMock() - req = AnthropicMessages(mock_app, { - 'base_url': 'https://custom.anthropic.com', - 'timeout': 60, - }) - - assert req.requester_cfg['base_url'] == 'https://custom.anthropic.com' - assert req.requester_cfg['timeout'] == 60 \ No newline at end of file diff --git a/tests/unit_tests/provider/requesters/test_chatcmpl_errors_direct.py b/tests/unit_tests/provider/requesters/test_chatcmpl_errors_direct.py deleted file mode 100644 index c51476c2..00000000 --- a/tests/unit_tests/provider/requesters/test_chatcmpl_errors_direct.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Tests for requester error handling - direct import version. - -Tests error handling branches by importing real packages and mocking -only the necessary dependencies. -""" - -from __future__ import annotations - -import asyncio -from unittest.mock import AsyncMock, MagicMock -import pytest -import openai # Import real openai package - -from langbot.pkg.provider.modelmgr.errors import RequesterError - - -class TestInvokeLLMErrorHandling: - """Tests for invoke_llm error handling branches.""" - - @pytest.fixture - def mock_app(self): - """Create mock Application.""" - app = MagicMock() - app.tool_mgr = MagicMock() - app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[]) - return app - - @pytest.fixture - def mock_model(self): - """Create mock RuntimeLLMModel.""" - model = MagicMock() - model.model_entity = MagicMock() - model.model_entity.name = 'gpt-4' - model.provider = MagicMock() - model.provider.token_mgr = MagicMock() - model.provider.token_mgr.get_token = MagicMock(return_value='test-key') - return model - - @pytest.fixture - def mock_message(self): - """Create mock provider message.""" - msg = MagicMock() - msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'}) - return msg - - @pytest.fixture - def requester_with_mocked_client(self, mock_app): - """Create requester with mocked OpenAI client.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - req = OpenAIChatCompletions(mock_app, { - 'base_url': 'https://api.openai.com/v1', - 'timeout': 120, - }) - - # Replace client with mock - req.client = MagicMock() - req.client.chat = MagicMock() - req.client.chat.completions = MagicMock() - req.client.chat.completions.create = AsyncMock() - - return req - - @pytest.mark.asyncio - async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message): - """TimeoutError is wrapped as RequesterError.""" - requester_with_mocked_client.client.chat.completions.create = AsyncMock( - side_effect=asyncio.TimeoutError() - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_llm( - query=None, - model=mock_model, - messages=[mock_message], - ) - - assert '超时' in str(exc.value) - - @pytest.mark.asyncio - async def test_bad_request_context_length(self, requester_with_mocked_client, mock_model, mock_message): - """BadRequestError with context_length_exceeded has special message.""" - error = openai.BadRequestError( - message='context_length_exceeded: max 4096', - response=MagicMock(status_code=400), - body={} - ) - requester_with_mocked_client.client.chat.completions.create = AsyncMock( - side_effect=error - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_llm( - query=None, - model=mock_model, - messages=[mock_message], - ) - - assert '上文过长' in str(exc.value) - - @pytest.mark.asyncio - async def test_authentication_error(self, requester_with_mocked_client, mock_model, mock_message): - """AuthenticationError shows invalid api-key message.""" - error = openai.AuthenticationError( - message='Invalid API key', - response=MagicMock(status_code=401), - body={} - ) - requester_with_mocked_client.client.chat.completions.create = AsyncMock( - side_effect=error - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_llm( - query=None, - model=mock_model, - messages=[mock_message], - ) - - assert 'api-key' in str(exc.value).lower() or '无效' in str(exc.value) - - @pytest.mark.asyncio - async def test_rate_limit_error(self, requester_with_mocked_client, mock_model, mock_message): - """RateLimitError shows rate limit message.""" - error = openai.RateLimitError( - message='Rate limit exceeded', - response=MagicMock(status_code=429), - body={} - ) - requester_with_mocked_client.client.chat.completions.create = AsyncMock( - side_effect=error - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_llm( - query=None, - model=mock_model, - messages=[mock_message], - ) - - assert '频繁' in str(exc.value) or '余额' in str(exc.value) - - -class TestInvokeEmbeddingErrorHandling: - """Tests for invoke_embedding error handling.""" - - @pytest.fixture - def mock_app(self): - return MagicMock() - - @pytest.fixture - def mock_embedding_model(self): - model = MagicMock() - model.model_entity = MagicMock() - model.model_entity.name = 'text-embedding-ada-002' - model.model_entity.extra_args = {} - model.provider = MagicMock() - model.provider.token_mgr = MagicMock() - model.provider.token_mgr.get_token = MagicMock(return_value='test-key') - return model - - @pytest.fixture - def requester_with_mocked_client(self, mock_app): - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - req = OpenAIChatCompletions(mock_app, {}) - req.client = MagicMock() - req.client.embeddings = MagicMock() - req.client.embeddings.create = AsyncMock() - - return req - - @pytest.mark.asyncio - async def test_embedding_timeout_error(self, requester_with_mocked_client, mock_embedding_model): - """TimeoutError in embedding request.""" - requester_with_mocked_client.client.embeddings.create = AsyncMock( - side_effect=asyncio.TimeoutError() - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_embedding( - model=mock_embedding_model, - input_text=['test'], - ) - - assert '超时' in str(exc.value) - - @pytest.mark.asyncio - async def test_embedding_bad_request_error(self, requester_with_mocked_client, mock_embedding_model): - """BadRequestError in embedding request.""" - error = openai.BadRequestError( - message='Invalid model', - response=MagicMock(status_code=400), - body={} - ) - requester_with_mocked_client.client.embeddings.create = AsyncMock( - side_effect=error - ) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_embedding( - model=mock_embedding_model, - input_text=['test'], - ) - - assert '参数' in str(exc.value) - - -class TestRequesterErrorClass: - """Tests for RequesterError.""" - - def test_error_message_prefix(self): - """RequesterError has '模型请求失败' prefix.""" - from langbot.pkg.provider.modelmgr.errors import RequesterError - - error = RequesterError('test error') - assert '模型请求失败' in str(error) - - def test_error_is_exception(self): - """RequesterError inherits Exception.""" - from langbot.pkg.provider.modelmgr.errors import RequesterError - - error = RequesterError('test') - assert isinstance(error, Exception) - - -class TestDefaultConfig: - """Tests for requester default config.""" - - def test_default_config(self): - """Check default_config values.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - assert OpenAIChatCompletions.default_config['base_url'] == 'https://api.openai.com/v1' - assert OpenAIChatCompletions.default_config['timeout'] == 120 - - def test_config_override(self): - """Config overrides defaults.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - req = OpenAIChatCompletions(MagicMock(), { - 'base_url': 'https://custom.com/v1', - 'timeout': 60, - }) - - assert req.requester_cfg['base_url'] == 'https://custom.com/v1' - assert req.requester_cfg['timeout'] == 60 diff --git a/tests/unit_tests/provider/requesters/test_chatcmpl_utils.py b/tests/unit_tests/provider/requesters/test_chatcmpl_utils.py deleted file mode 100644 index 38d9df1c..00000000 --- a/tests/unit_tests/provider/requesters/test_chatcmpl_utils.py +++ /dev/null @@ -1,340 +0,0 @@ -"""Tests for requester pure utility functions. - -Tests the helper methods in OpenAIChatCompletions that don't require network calls. -""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -from tests.utils.import_isolation import isolated_sys_modules - - -class TestMaskApiKey: - """Tests for _mask_api_key method.""" - - def _create_requester_with_mocks(self): - """Create requester instance with mocked dependencies.""" - mocks = { - 'langbot.pkg.core.app': MagicMock(), - 'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(), - 'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(), - 'langbot_plugin.api.entities.builtin.provider.message': MagicMock(), - 'langbot.pkg.provider.modelmgr.errors': MagicMock(), - } - - with isolated_sys_modules(mocks): - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - mock_app = MagicMock() - requester = OpenAIChatCompletions(mock_app, {}) - return requester - - def test_mask_api_key_full(self): - """Mask a full API key.""" - requester = self._create_requester_with_mocks() - - result = requester._mask_api_key('sk-1234567890abcdef') - assert result == 'sk-1...cdef' - - def test_mask_api_key_short(self): - """Mask a short API key (<=8 chars).""" - requester = self._create_requester_with_mocks() - - result = requester._mask_api_key('short') - assert result == '****' - - def test_mask_api_key_empty(self): - """Empty API key returns empty string.""" - requester = self._create_requester_with_mocks() - - result = requester._mask_api_key('') - assert result == '' - - def test_mask_api_key_none(self): - """None API key returns empty string.""" - requester = self._create_requester_with_mocks() - - result = requester._mask_api_key(None) - assert result == '' - - def test_mask_api_key_exact_8_chars(self): - """API key with exactly 8 chars is masked as **** (<=8 threshold).""" - requester = self._create_requester_with_mocks() - - result = requester._mask_api_key('12345678') - assert result == '****' # <= 8 chars gets masked - - -class TestInferModelType: - """Tests for _infer_model_type method.""" - - def _create_requester_with_mocks(self): - mocks = { - 'langbot.pkg.core.app': MagicMock(), - 'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(), - 'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(), - 'langbot_plugin.api.entities.builtin.provider.message': MagicMock(), - 'langbot.pkg.provider.modelmgr.errors': MagicMock(), - } - - with isolated_sys_modules(mocks): - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - mock_app = MagicMock() - requester = OpenAIChatCompletions(mock_app, {}) - return requester - - def test_infer_embedding_from_name(self): - """Infer embedding type from model name.""" - requester = self._create_requester_with_mocks() - - assert requester._infer_model_type('text-embedding-ada-002') == 'embedding' - assert requester._infer_model_type('bge-large-en') == 'embedding' - assert requester._infer_model_type('e5-base') == 'embedding' - assert requester._infer_model_type('m3e-base') == 'embedding' - - def test_infer_llm_from_name(self): - """Infer LLM type from model name.""" - requester = self._create_requester_with_mocks() - - assert requester._infer_model_type('gpt-4') == 'llm' - assert requester._infer_model_type('claude-3-opus') == 'llm' - assert requester._infer_model_type('llama-2-70b') == 'llm' - - def test_infer_model_type_none_id(self): - """Handle None model_id.""" - requester = self._create_requester_with_mocks() - - result = requester._infer_model_type(None) - assert result == 'llm' # Default - - def test_infer_model_type_empty_id(self): - """Handle empty model_id.""" - requester = self._create_requester_with_mocks() - - result = requester._infer_model_type('') - assert result == 'llm' # Default - - -class TestNormalizeModalities: - """Tests for _normalize_modalities method.""" - - def _create_requester_with_mocks(self): - mocks = { - 'langbot.pkg.core.app': MagicMock(), - 'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(), - 'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(), - 'langbot_plugin.api.entities.builtin.provider.message': MagicMock(), - 'langbot.pkg.provider.modelmgr.errors': MagicMock(), - } - - with isolated_sys_modules(mocks): - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - mock_app = MagicMock() - requester = OpenAIChatCompletions(mock_app, {}) - return requester - - def test_normalize_string_modality(self): - """Normalize single string modality.""" - requester = self._create_requester_with_mocks() - - result = requester._normalize_modalities('text,image') - assert result == ['text', 'image'] - - def test_normalize_list_modalities(self): - """Normalize list of modalities.""" - requester = self._create_requester_with_mocks() - - result = requester._normalize_modalities(['text', 'image', 'audio']) - assert result == ['text', 'image', 'audio'] - - def test_normalize_dict_modalities(self): - """Normalize dict with nested modalities.""" - requester = self._create_requester_with_mocks() - - result = requester._normalize_modalities({'input': ['text'], 'output': ['text', 'image']}) - assert result == ['text', 'image'] - - def test_normalize_none(self): - """Handle None input.""" - requester = self._create_requester_with_mocks() - - result = requester._normalize_modalities(None) - assert result == [] - - def test_normalize_arrow_separator(self): - """Handle arrow separator in modality string.""" - requester = self._create_requester_with_mocks() - - result = requester._normalize_modalities('text->image') - assert result == ['text', 'image'] - - -class TestParseRerankResponse: - """Tests for _parse_rerank_response static method.""" - - def test_parse_cohere_jina_format(self): - """Parse Cohere/Jina/SiliconFlow format.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - data = { - 'results': [ - {'index': 0, 'relevance_score': 0.95}, - {'index': 1, 'relevance_score': 0.80}, - ] - } - - result = OpenAIChatCompletions._parse_rerank_response(data) - assert result == [ - {'index': 0, 'relevance_score': 0.95}, - {'index': 1, 'relevance_score': 0.80}, - ] - - def test_parse_voyage_format(self): - """Parse Voyage AI format.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - data = { - 'data': [ - {'index': 0, 'relevance_score': 0.90}, - {'index': 2, 'relevance_score': 0.75}, - ] - } - - result = OpenAIChatCompletions._parse_rerank_response(data) - assert result == [ - {'index': 0, 'relevance_score': 0.90}, - {'index': 2, 'relevance_score': 0.75}, - ] - - def test_parse_dashscope_format(self): - """Parse DashScope format.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - data = { - 'output': { - 'results': [ - {'index': 0, 'relevance_score': 0.85}, - ] - } - } - - result = OpenAIChatCompletions._parse_rerank_response(data) - assert result == [{'index': 0, 'relevance_score': 0.85}] - - def test_parse_unknown_format(self): - """Handle unknown format returns empty list.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - data = {'unknown_key': 'value'} - - result = OpenAIChatCompletions._parse_rerank_response(data) - assert result == [] - - def test_parse_empty_results(self): - """Handle empty results.""" - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - data = {'results': []} - - result = OpenAIChatCompletions._parse_rerank_response(data) - assert result == [] - - -class TestExtractScanMetadata: - """Tests for _extract_scan_metadata method.""" - - def _create_requester_with_mocks(self): - mocks = { - 'langbot.pkg.core.app': MagicMock(), - 'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(), - 'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(), - 'langbot_plugin.api.entities.builtin.provider.message': MagicMock(), - 'langbot.pkg.provider.modelmgr.errors': MagicMock(), - } - - with isolated_sys_modules(mocks): - from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions - - mock_app = MagicMock() - requester = OpenAIChatCompletions(mock_app, {}) - return requester - - def test_extract_basic_metadata(self): - """Extract basic model metadata.""" - requester = self._create_requester_with_mocks() - - item = { - 'id': 'gpt-4', - 'name': 'GPT-4 Turbo', - 'description': 'Most capable GPT-4 model', - 'context_length': 128000, - 'owned_by': 'openai', - } - - result = requester._extract_scan_metadata(item, 'gpt-4') - - assert result['display_name'] == 'GPT-4 Turbo' - assert result['description'] == 'Most capable GPT-4 model' - assert result['context_length'] == 128000 - assert result['owned_by'] == 'openai' - - def test_extract_metadata_missing_fields(self): - """Handle missing metadata fields.""" - requester = self._create_requester_with_mocks() - - item = {'id': 'unknown-model'} - - result = requester._extract_scan_metadata(item, 'unknown-model') - - assert result['display_name'] is None - assert result['description'] is None - assert result['context_length'] is None - assert result['owned_by'] is None - - def test_extract_metadata_top_provider_context(self): - """Extract context_length from top_provider.""" - requester = self._create_requester_with_mocks() - - item = { - 'id': 'model', - 'top_provider': { - 'context_length': 4096, - }, - } - - result = requester._extract_scan_metadata(item, 'model') - - assert result['context_length'] == 4096 - - def test_extract_metadata_empty_strings(self): - """Handle empty string values.""" - requester = self._create_requester_with_mocks() - - item = { - 'id': 'model', - 'name': '', # Empty name - 'description': ' ', # Whitespace only - 'owned_by': '', - } - - result = requester._extract_scan_metadata(item, 'model') - - assert result['display_name'] is None - assert result['description'] is None - assert result['owned_by'] is None - - def test_extract_metadata_name_matches_id(self): - """When name equals id, display_name is None.""" - requester = self._create_requester_with_mocks() - - item = { - 'id': 'gpt-4', - 'name': 'gpt-4', # Same as id - } - - result = requester._extract_scan_metadata(item, 'gpt-4') - - assert result['display_name'] is None diff --git a/tests/unit_tests/provider/requesters/test_ollama_requester.py b/tests/unit_tests/provider/requesters/test_ollama_requester.py deleted file mode 100644 index 993115ab..00000000 --- a/tests/unit_tests/provider/requesters/test_ollama_requester.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Tests for OllamaChatCompletions requester. - -Tests model inference, payload construction, and error handling. -""" - -from __future__ import annotations - -import asyncio -from unittest.mock import AsyncMock, MagicMock -import pytest - -from langbot.pkg.provider.modelmgr.errors import RequesterError - - -class TestOllamaRequesterConfig: - """Tests for default config.""" - - def test_default_config_values(self): - """Check default_config.""" - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - assert OllamaChatCompletions.default_config['base_url'] == 'http://127.0.0.1:11434' - assert OllamaChatCompletions.default_config['timeout'] == 120 - - def test_config_override(self): - """Config can override defaults.""" - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - mock_app = MagicMock() - req = OllamaChatCompletions(mock_app, { - 'base_url': 'http://custom.ollama:11434', - 'timeout': 300, - }) - - assert req.requester_cfg['base_url'] == 'http://custom.ollama:11434' - assert req.requester_cfg['timeout'] == 300 - - -class TestOllamaInferModelType: - """Tests for _infer_model_type pure function.""" - - @pytest.fixture - def requester(self): - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - return OllamaChatCompletions(MagicMock(), {}) - - def test_infer_embedding_from_name(self, requester): - """Embedding keywords return 'embedding'.""" - assert requester._infer_model_type('nomic-embed-text') == 'embedding' - assert requester._infer_model_type('bge-large') == 'embedding' - assert requester._infer_model_type('text-embedding') == 'embedding' - - def test_infer_llm_from_name(self, requester): - """Non-embedding keywords return 'llm'.""" - assert requester._infer_model_type('llama2') == 'llm' - assert requester._infer_model_type('mistral') == 'llm' - assert requester._infer_model_type('codellama') == 'llm' - - def test_infer_model_type_none(self, requester): - """None model_id returns 'llm'.""" - assert requester._infer_model_type(None) == 'llm' - - def test_infer_model_type_empty(self, requester): - """Empty model_id returns 'llm'.""" - assert requester._infer_model_type('') == 'llm' - - -class TestOllamaInferModelAbilities: - """Tests for _infer_model_abilities pure function.""" - - @pytest.fixture - def requester(self): - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - return OllamaChatCompletions(MagicMock(), {}) - - def test_infer_vision_ability(self, requester): - """Vision keywords add 'vision' ability.""" - item = { - 'details': { - 'family': 'llava', - } - } - - abilities = requester._infer_model_abilities(item, 'llava-v1.5') - assert 'vision' in abilities - - def test_infer_vision_from_model_id(self, requester): - """Vision keywords in model_id add 'vision' ability.""" - item = {} - abilities = requester._infer_model_abilities(item, 'llava-7b') - assert 'vision' in abilities - - def test_infer_func_call_ability(self, requester): - """Tool/function keywords add 'func_call' ability.""" - item = { - 'details': { - 'families': ['tools'], - } - } - - abilities = requester._infer_model_abilities(item, 'model') - assert 'func_call' in abilities - - def test_infer_no_abilities(self, requester): - """No matching keywords returns empty abilities.""" - item = { - 'details': { - 'family': 'llama', - } - } - - abilities = requester._infer_model_abilities(item, 'llama-2') - assert len(abilities) == 0 - - def test_infer_multiple_abilities(self, requester): - """Multiple keywords can add multiple abilities.""" - item = { - 'details': { - 'family': 'vision', - 'families': ['tools'], - } - } - - abilities = requester._infer_model_abilities(item, 'vision-tool-model') - assert 'vision' in abilities - assert 'func_call' in abilities - - -class TestOllamaMakeMessage: - """Tests for _make_msg response parsing.""" - - @pytest.fixture - def requester(self): - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - return OllamaChatCompletions(MagicMock(), {}) - - def _create_ollama_response(self, content, tool_calls=None): - """Helper to create mock ollama response.""" - import ollama - - mock_response = MagicMock(spec=ollama.ChatResponse) - mock_message = MagicMock(spec=ollama.Message) - mock_message.content = content - mock_message.tool_calls = tool_calls - mock_response.message = mock_message - - return mock_response - - @pytest.mark.asyncio - async def test_make_msg_text_content(self, requester): - """Text content is extracted.""" - mock_response = self._create_ollama_response('Hello world') - - result = await requester._make_msg(mock_response) - - assert result.content == 'Hello world' - assert result.role == 'assistant' - - @pytest.mark.asyncio - async def test_make_msg_with_tool_calls(self, requester): - """Tool calls are parsed.""" - mock_tool_call = MagicMock() - mock_tool_call.function = MagicMock() - mock_tool_call.function.name = 'get_weather' - mock_tool_call.function.arguments = {'location': 'Beijing'} - - mock_response = self._create_ollama_response('', tool_calls=[mock_tool_call]) - - result = await requester._make_msg(mock_response) - - assert result.tool_calls is not None - assert len(result.tool_calls) == 1 - assert result.tool_calls[0].function.name == 'get_weather' - # Arguments should be JSON string - assert isinstance(result.tool_calls[0].function.arguments, str) - - @pytest.mark.asyncio - async def test_make_msg_empty_message_raises(self, requester): - """Empty message raises ValueError.""" - mock_response = MagicMock() - mock_response.message = None - - with pytest.raises(ValueError, match='message'): - await requester._make_msg(mock_response) - - -class TestOllamaErrorHandling: - """Tests for error handling branches.""" - - @pytest.fixture - def mock_app(self): - app = MagicMock() - app.tool_mgr = MagicMock() - app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[]) - return app - - @pytest.fixture - def requester_with_mocked_client(self, mock_app): - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - req = OllamaChatCompletions(mock_app, {}) - req.client = MagicMock() - req.client.chat = AsyncMock() - - return req - - @pytest.fixture - def mock_model(self): - model = MagicMock() - model.model_entity = MagicMock() - model.model_entity.name = 'llama2' - model.provider = MagicMock() - model.provider.token_mgr = MagicMock() - model.provider.token_mgr.get_token = MagicMock(return_value='') - return model - - @pytest.fixture - def mock_message(self): - msg = MagicMock() - msg.role = 'user' - msg.content = 'test' - msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'}) - return msg - - @pytest.mark.asyncio - async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message): - """TimeoutError is converted to RequesterError.""" - requester_with_mocked_client.client.chat = AsyncMock(side_effect=asyncio.TimeoutError()) - - with pytest.raises(RequesterError) as exc: - await requester_with_mocked_client.invoke_llm( - query=None, - model=mock_model, - messages=[mock_message], - ) - - assert '超时' in str(exc.value) - - -class TestOllamaScanModels: - """Tests for scan_models method.""" - - @pytest.fixture - def mock_app(self): - return MagicMock() - - @pytest.fixture - def requester(self, mock_app): - from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions - - req = OllamaChatCompletions(mock_app, { - 'base_url': 'http://127.0.0.1:11434', - 'timeout': 120, - }) - return req - - def test_requester_name_constant(self): - """REQUESTER_NAME constant exists.""" - from langbot.pkg.provider.modelmgr.requesters.ollamachat import REQUESTER_NAME - - assert REQUESTER_NAME == 'ollama-chat' diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py new file mode 100644 index 00000000..1ec12d82 --- /dev/null +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -0,0 +1,1126 @@ +""" +Tests for LiteLLMRequester - unified requester for chat, embedding, and rerank. + +These tests verify: +- Parameter building and LiteLLM API calls +- Response processing and usage extraction +- Error handling and exception translation +- Model name building with provider prefix +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +import litellm + +from langbot.pkg.provider.modelmgr.requesters import litellmchat +from langbot.pkg.provider.modelmgr import errors + + +class MockRuntimeModel: + """Mock RuntimeLLMModel for testing""" + + def __init__(self, model_name: str = 'gpt-4o', api_key: str = 'test-key'): + self.model_entity = Mock() + self.model_entity.name = model_name + self.model_entity.extra_args = {} + self.provider = Mock() + self.provider.token_mgr = Mock() + self.provider.token_mgr.get_token = Mock(return_value=api_key) + + +class MockRuntimeEmbeddingModel: + """Mock RuntimeEmbeddingModel for testing""" + + def __init__(self, model_name: str = 'text-embedding-3-small', api_key: str = 'test-key'): + self.model_entity = Mock() + self.model_entity.name = model_name + self.model_entity.extra_args = {} + self.provider = Mock() + self.provider.token_mgr = Mock() + self.provider.token_mgr.get_token = Mock(return_value=api_key) + + +class MockRuntimeRerankModel: + """Mock RuntimeRerankModel for testing""" + + def __init__(self, model_name: str = 'cohere/rerank-english-v3.0', api_key: str = 'test-key'): + self.model_entity = Mock() + self.model_entity.name = model_name + self.model_entity.extra_args = {} + self.provider = Mock() + self.provider.token_mgr = Mock() + self.provider.token_mgr.get_token = Mock(return_value=api_key) + + +class TestBuildLiteLLMModelName: + """Test _build_litellm_model_name method""" + + def test_no_provider_prefix(self): + """Test model name without provider prefix""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': ''}) + result = requester._build_litellm_model_name('gpt-4o') + assert result == 'gpt-4o' + + def test_with_provider_prefix(self): + """Test model name with provider prefix""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'}) + result = requester._build_litellm_model_name('gpt-4o') + assert result == 'openai/gpt-4o' + + def test_avoid_duplicate_provider_prefix(self): + """Test model name with an existing matching provider prefix.""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'}) + result = requester._build_litellm_model_name('openai/gpt-4o') + assert result == 'openai/gpt-4o' + + def test_override_provider(self): + """Test override provider via parameter""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'openai'}) + result = requester._build_litellm_model_name('claude-3', custom_llm_provider='anthropic') + assert result == 'anthropic/claude-3' + + +class TestExtractUsage: + """Test _extract_usage method""" + + def test_extract_usage_with_data(self): + """Test extraction with valid usage data""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + response = Mock() + response.usage = Mock() + response.usage.prompt_tokens = 100 + response.usage.completion_tokens = 50 + response.usage.total_tokens = 150 + + result = requester._extract_usage(response) + + assert result['prompt_tokens'] == 100 + assert result['completion_tokens'] == 50 + assert result['total_tokens'] == 150 + + def test_extract_usage_with_zero_values(self): + """Test extraction when values are 0""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + response = Mock() + response.usage = Mock() + response.usage.prompt_tokens = 0 + response.usage.completion_tokens = 0 + response.usage.total_tokens = 0 + + result = requester._extract_usage(response) + + assert result['prompt_tokens'] == 0 + assert result['completion_tokens'] == 0 + + +class TestNormalizeUsage: + """Test _normalize_usage helper covering real-world usage shapes""" + + def test_none_usage(self): + """None usage -> all zeros (no crash)""" + result = litellmchat.LiteLLMRequester._normalize_usage(None) + assert result == {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0} + + def test_dict_usage(self): + """Usage given as a plain dict""" + result = litellmchat.LiteLLMRequester._normalize_usage( + {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20} + ) + assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20} + + def test_missing_total_is_derived(self): + """When total_tokens is absent/zero it is derived from prompt + completion""" + usage = Mock() + usage.prompt_tokens = 42 + usage.completion_tokens = 10 + usage.total_tokens = 0 + result = litellmchat.LiteLLMRequester._normalize_usage(usage) + assert result['total_tokens'] == 52 + + def test_partial_attrs_default_to_zero(self): + """Missing attributes default to 0 instead of raising""" + usage = Mock(spec=['prompt_tokens']) + usage.prompt_tokens = 5 + result = litellmchat.LiteLLMRequester._normalize_usage(usage) + assert result == {'prompt_tokens': 5, 'completion_tokens': 0, 'total_tokens': 5} + + +class TestInvokeLLMStreamUsage: + """Regression tests for streaming token usage capture. + + Real OpenAI-compatible gateways (e.g. new-api) send the final usage payload + in a chunk that still carries a (empty-delta) choice rather than an empty + `choices` list. The usage must be captured regardless, otherwise streamed + calls record 0 tokens. + """ + + def _make_chunk(self, *, content=None, tool_calls=None, finish_reason=None, usage=None, has_choice=True): + chunk = Mock() + if usage is not None: + chunk.usage = usage + else: + chunk.usage = None + if has_choice: + choice = Mock() + delta = Mock() + delta.model_dump = Mock( + return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls} + ) + choice.delta = delta + choice.finish_reason = finish_reason + chunk.choices = [choice] + else: + chunk.choices = [] + return chunk + + @pytest.mark.asyncio + async def test_stream_usage_with_nonempty_choices(self): + """Usage chunk that still has a choice must populate _stream_usage.""" + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None) + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + usage = Mock() + usage.prompt_tokens = 24 + usage.completion_tokens = 48 + usage.total_tokens = 72 + + chunks = [ + self._make_chunk(content='Hello'), + self._make_chunk(content=None, finish_reason='stop'), + # Final usage chunk WITH a non-empty (empty-delta) choice — the bug case. + self._make_chunk(content=None, usage=usage, has_choice=True), + ] + + async def _aiter(*args, **kwargs): + for c in chunks: + yield c + + query = Mock(spec=pipeline_query.Query) + query.variables = {} + + messages = [provider_message.Message(role='user', content='Hi')] + + with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())): + collected = [] + async for ch in requester.invoke_llm_stream(query=query, model=model, messages=messages): + collected.append(ch) + + assert '_stream_usage' in query.variables + assert query.variables['_stream_usage']['prompt_tokens'] == 24 + assert query.variables['_stream_usage']['completion_tokens'] == 48 + assert query.variables['_stream_usage']['total_tokens'] == 72 + + @pytest.mark.asyncio + async def test_stream_usage_with_empty_choices(self): + """Usage chunk with empty choices list must also populate _stream_usage.""" + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None) + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + usage = Mock() + usage.prompt_tokens = 5 + usage.completion_tokens = 7 + usage.total_tokens = 12 + + chunks = [ + self._make_chunk(content='Hi there'), + self._make_chunk(content=None, finish_reason='stop'), + self._make_chunk(usage=usage, has_choice=False), + ] + + async def _aiter(*args, **kwargs): + for c in chunks: + yield c + + query = Mock(spec=pipeline_query.Query) + query.variables = {} + messages = [provider_message.Message(role='user', content='Hi')] + + with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())): + async for _ in requester.invoke_llm_stream(query=query, model=model, messages=messages): + pass + + assert query.variables['_stream_usage']['total_tokens'] == 12 + + @pytest.mark.asyncio + async def test_stream_tool_call_delta_missing_id_and_name(self): + """LiteLLM may stream tool-call argument deltas with id/name set to None.""" + import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock( + return_value=[{'type': 'function', 'function': {'name': 'qa_plugin_echo'}}] + ) + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + chunks = [ + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'id': 'call_123', + 'type': 'function', + 'function': {'name': 'qa_plugin_echo', 'arguments': ''}, + } + ] + ), + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'id': None, + 'type': None, + 'function': {'name': None, 'arguments': '{"text":'}, + } + ] + ), + self._make_chunk( + tool_calls=[ + { + 'index': 0, + 'function': {'arguments': '"plugin-tool-ok"}'}, + } + ] + ), + self._make_chunk(finish_reason='tool_calls'), + ] + + async def _aiter(*args, **kwargs): + for c in chunks: + yield c + + query = Mock(spec=pipeline_query.Query) + query.variables = {} + messages = [provider_message.Message(role='user', content='Call the tool')] + funcs = [Mock()] + + with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())): + collected = [ + chunk async for chunk in requester.invoke_llm_stream( + query=query, + model=model, + messages=messages, + funcs=funcs, + ) + ] + + tool_chunks = [chunk for chunk in collected if chunk.tool_calls] + assert len(tool_chunks) == 3 + assert tool_chunks[1].tool_calls[0].id == 'call_123' + assert tool_chunks[1].tool_calls[0].function.name == 'qa_plugin_echo' + assert tool_chunks[1].tool_calls[0].function.arguments == '{"text":' + assert tool_chunks[2].tool_calls[0].function.arguments == '"plugin-tool-ok"}' + + +class TestProcessThinkingContent: + """Test _process_thinking_content method""" + + def test_no_thinking_markers(self): + """Test content without thinking markers""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + result = requester._process_thinking_content('Hello world', None, remove_think=True) + assert result == 'Hello world' + + def test_remove_thinking_markers(self): + """Test removing thinking markers when remove_think=True""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.' + result = requester._process_thinking_content(content, None, remove_think=True) + assert result == 'The answer is 42.' + + def test_preserve_thinking_markers(self): + """Test preserving thinking markers when remove_think=False""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + content = 'CRETIRE_REASONING_BEGINkLet me think...CRETIRE_REASONING_ENDk The answer is 42.' + result = requester._process_thinking_content(content, None, remove_think=False) + assert 'CRETIRE_REASONING_BEGINk' in result + assert 'The answer is 42.' in result + + def test_empty_content(self): + """Test empty content""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + result = requester._process_thinking_content('', None, remove_think=True) + assert result == '' + + +class TestBuildCommonArgs: + """Test _build_common_args method""" + + def test_build_args_with_all_params(self): + """Test building args with all config params""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + 'drop_params': True, + 'num_retries': 3, + 'api_version': '2024-01-01', + }, + ) + + args = {} + requester._build_common_args(args) + + assert args['api_base'] == 'https://api.openai.com/v1' + assert args['timeout'] == 60 + assert args['drop_params'] == True + assert args['num_retries'] == 3 + assert args['api_version'] == '2024-01-01' + + def test_build_args_without_retry_params(self): + """Test building args without retry params for embedding/rerank""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + 'num_retries': 3, + }, + ) + + args = {} + requester._build_common_args(args, include_retry_params=False) + + assert args['api_base'] == 'https://api.openai.com/v1' + assert args['timeout'] == 60 + assert 'num_retries' not in args + + +class TestHandleLiteLLMError: + """Test _handle_litellm_error method""" + + def test_bad_request_error(self): + """Test BadRequestError translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + # Create proper LiteLLM exception with required args + error = litellm.BadRequestError(message='test error', model='gpt-4o', llm_provider='openai') + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(error) + + assert '请求参数错误' in str(exc_info.value) + + def test_authentication_error(self): + """Test AuthenticationError translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai') + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(error) + + assert 'API key 无效' in str(exc_info.value) + + def test_rate_limit_error(self): + """Test RateLimitError translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + error = litellm.RateLimitError(message='rate limited', model='gpt-4o', llm_provider='openai') + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(error) + + assert '请求过于频繁' in str(exc_info.value) + + def test_timeout_error(self): + """Test Timeout translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + error = litellm.Timeout(message='timeout', model='gpt-4o', llm_provider='openai') + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(error) + + assert '请求超时' in str(exc_info.value) + + def test_context_window_error(self): + """Test ContextWindowExceededError translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + error = litellm.ContextWindowExceededError(message='context too long', model='gpt-4o', llm_provider='openai') + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(error) + + assert '上下文长度超限' in str(exc_info.value) + + def test_unknown_error(self): + """Test unknown error translation""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + with pytest.raises(errors.RequesterError) as exc_info: + requester._handle_litellm_error(Exception('unknown')) + + assert '未知错误' in str(exc_info.value) + + +class TestInvokeLLM: + """Test invoke_llm method""" + + @pytest.mark.asyncio + async def test_invoke_llm_basic(self): + """Test basic LLM invocation""" + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None) + + requester = litellmchat.LiteLLMRequester( + ap=mock_ap, + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + }, + ) + + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + # Mock LiteLLM response + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message = Mock() + mock_response.choices[0].message.model_dump = Mock( + return_value={ + 'role': 'assistant', + 'content': 'Hello! How can I help you?', + } + ) + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.usage.total_tokens = 30 + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [provider_message.Message(role='user', content='Hello')] + + # Patch acompletion at the import location + with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response): + result_msg, usage = await requester.invoke_llm( + query=None, + model=model, + messages=messages, + ) + + assert result_msg.role == 'assistant' + assert result_msg.content == 'Hello! How can I help you?' + assert usage['prompt_tokens'] == 10 + assert usage['completion_tokens'] == 20 + + @pytest.mark.asyncio + async def test_invoke_llm_with_tools(self): + """Test LLM invocation with function calling""" + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock( + return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}] + ) + + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message = Mock() + mock_response.choices[0].message.model_dump = Mock( + return_value={ + 'role': 'assistant', + 'content': None, + 'tool_calls': [ + {'id': 'call_123', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}} + ], + } + ) + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + mock_response.usage.total_tokens = 25 + + import langbot_plugin.api.entities.builtin.resource.tool as resource_tool + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [provider_message.Message(role='user', content='What is the weather?')] + # Create proper LLMTool with all required fields + funcs = [Mock(spec=resource_tool.LLMTool)] + funcs[0].name = 'get_weather' + funcs[0].description = 'Get weather' + + with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, return_value=mock_response): + result_msg, usage = await requester.invoke_llm( + query=None, + model=model, + messages=messages, + funcs=funcs, + ) + + assert result_msg.tool_calls is not None + called_kwargs = litellmchat.acompletion.await_args.kwargs + assert called_kwargs['tools'] == [{'type': 'function', 'function': {'name': 'get_weather'}}] + assert called_kwargs['tool_choice'] == 'auto' + + @pytest.mark.asyncio + async def test_build_completion_args_preserves_explicit_tool_choice(self): + """Model extra args can override the default auto tool choice.""" + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock( + return_value=[{'type': 'function', 'function': {'name': 'get_weather'}}] + ) + + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + model = MockRuntimeModel('gpt-4o', 'test-api-key') + model.model_entity.extra_args = {'tool_choice': 'required'} + + import langbot_plugin.api.entities.builtin.resource.tool as resource_tool + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + funcs = [Mock(spec=resource_tool.LLMTool)] + messages = [provider_message.Message(role='user', content='What is the weather?')] + + args = await requester._build_completion_args(model, messages, funcs) + + assert args['tool_choice'] == 'required' + + @pytest.mark.asyncio + async def test_invoke_llm_error_handling(self): + """Test LLM invocation error handling""" + mock_ap = Mock() + mock_ap.tool_mgr = Mock() + mock_ap.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=None) + + requester = litellmchat.LiteLLMRequester(ap=mock_ap, config={}) + + model = MockRuntimeModel('gpt-4o', 'test-api-key') + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [provider_message.Message(role='user', content='Hello')] + + error = litellm.AuthenticationError(message='invalid key', model='gpt-4o', llm_provider='openai') + + with patch.object(litellmchat, 'acompletion', new_callable=AsyncMock, side_effect=error): + with pytest.raises(errors.RequesterError) as exc_info: + await requester.invoke_llm( + query=None, + model=model, + messages=messages, + ) + + assert 'API key 无效' in str(exc_info.value) + + +class TestInvokeEmbedding: + """Test invoke_embedding method""" + + @pytest.mark.asyncio + async def test_invoke_embedding_basic(self): + """Test basic embedding invocation""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + }, + ) + + model = MockRuntimeEmbeddingModel('text-embedding-3-small', 'test-api-key') + + # Mock LiteLLM embedding response + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2, 0.3]), + Mock(embedding=[0.4, 0.5, 0.6]), + ] + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 0 + mock_response.usage.total_tokens = 20 + + with patch.object(litellmchat, 'aembedding', new_callable=AsyncMock, return_value=mock_response): + embeddings, usage = await requester.invoke_embedding( + model=model, + input_text=['Hello', 'World'], + ) + + assert len(embeddings) == 2 + assert embeddings[0] == [0.1, 0.2, 0.3] + assert embeddings[1] == [0.4, 0.5, 0.6] + assert usage['prompt_tokens'] == 20 + + +class TestInvokeRerank: + """Test invoke_rerank method""" + + @pytest.mark.asyncio + async def test_invoke_rerank_basic(self): + """Test basic rerank invocation""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.cohere.ai', + 'custom_llm_provider': 'cohere', + }, + ) + + model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key') + + # Mock LiteLLM rerank response + mock_response = Mock() + mock_response.results = [ + {'index': 0, 'relevance_score': 0.95}, + {'index': 1, 'relevance_score': 0.3}, + {'index': 2, 'relevance_score': 0.8}, + ] + + with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response): + results = await requester.invoke_rerank( + model=model, + query='What is the capital of France?', + documents=['Paris is the capital.', 'London is a city.', 'France is in Europe.'], + ) + + assert len(results) == 3 + # Scores should be normalized + assert results[0]['index'] == 0 + assert results[0]['relevance_score'] >= 0 and results[0]['relevance_score'] <= 1 + + @pytest.mark.asyncio + async def test_invoke_rerank_normalization(self): + """Test rerank score normalization""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'cohere'}) + + model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key') + + # Mock response with varying scores + mock_response = Mock() + mock_response.results = [ + {'index': 0, 'relevance_score': 0.9}, + {'index': 1, 'relevance_score': 0.1}, + ] + + with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response): + results = await requester.invoke_rerank( + model=model, + query='test query', + documents=['doc1', 'doc2'], + ) + + # After normalization: 0.9 -> 1.0, 0.1 -> 0.0 + assert results[0]['relevance_score'] == 1.0 + assert results[1]['relevance_score'] == 0.0 + + @pytest.mark.asyncio + async def test_invoke_rerank_single_document(self): + """Test rerank with single document (no normalization needed)""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={'custom_llm_provider': 'cohere'}) + + model = MockRuntimeRerankModel('rerank-english-v3.0', 'test-api-key') + + mock_response = Mock() + mock_response.results = [ + {'index': 0, 'relevance_score': 0.5}, + ] + + with patch.object(litellmchat, 'arerank', new_callable=AsyncMock, return_value=mock_response): + results = await requester.invoke_rerank( + model=model, + query='test query', + documents=['doc1'], + ) + + assert len(results) == 1 + # Single score stays as is (min==max, no normalization) + assert results[0]['relevance_score'] == 0.5 + + @pytest.mark.asyncio + async def test_invoke_rerank_openai_compatible_http(self): + """OpenAI-compatible gateways (newapi/one-api/vLLM) must use the HTTP /v1/rerank + endpoint instead of litellm.arerank, which rejects the 'openai' provider.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://gateway.example.com/v1', + 'custom_llm_provider': 'openai', + }, + ) + + model = MockRuntimeRerankModel('bge-reranker-v2-m3', 'test-api-key') + + # Mock the standard Jina/Cohere-style /v1/rerank HTTP response + mock_resp = Mock() + mock_resp.raise_for_status = Mock() + mock_resp.json = Mock( + return_value={ + 'results': [ + {'index': 0, 'relevance_score': 0.9}, + {'index': 1, 'relevance_score': 0.1}, + ] + } + ) + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch('httpx.AsyncClient', return_value=mock_client): + # arerank must NOT be called on the openai-compatible path + with patch.object( + litellmchat, 'arerank', new_callable=AsyncMock, + side_effect=AssertionError('arerank must not be used for openai-compatible provider'), + ): + results = await requester.invoke_rerank( + model=model, + query='test query', + documents=['doc1', 'doc2'], + ) + + # Hit the standard rerank endpoint + called_url = mock_client.post.call_args[0][0] + assert called_url == 'https://gateway.example.com/v1/rerank' + # Payload carries the raw model name, query and documents + payload = mock_client.post.call_args[1]['json'] + assert payload['model'] == 'bge-reranker-v2-m3' + assert payload['query'] == 'test query' + assert payload['documents'] == ['doc1', 'doc2'] + # Scores normalized: 0.9 -> 1.0, 0.1 -> 0.0 + assert results[0]['relevance_score'] == 1.0 + assert results[1]['relevance_score'] == 0.0 + + @pytest.mark.asyncio + async def test_invoke_rerank_openai_compatible_score_alias(self): + """Some gateways return 'score' instead of 'relevance_score'; both must work.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://gateway.example.com/v1', + 'custom_llm_provider': 'openai', + }, + ) + + model = MockRuntimeRerankModel('bge-reranker-v2-m3', 'test-api-key') + + mock_resp = Mock() + mock_resp.raise_for_status = Mock() + mock_resp.json = Mock( + return_value={ + 'results': [ + {'index': 0, 'score': 0.8}, + {'index': 1, 'score': 0.2}, + ] + } + ) + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch('httpx.AsyncClient', return_value=mock_client): + results = await requester.invoke_rerank( + model=model, + query='test query', + documents=['doc1', 'doc2'], + ) + + assert results[0]['relevance_score'] == 1.0 + assert results[1]['relevance_score'] == 0.0 + + +class TestConvertMessages: + """Test _convert_messages method""" + + def test_convert_simple_message(self): + """Test converting simple text message""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [provider_message.Message(role='user', content='Hello')] + result = requester._convert_messages(messages) + + assert len(result) == 1 + assert result[0]['role'] == 'user' + assert result[0]['content'] == 'Hello' + + def test_convert_message_with_image_base64(self): + """Test converting message with image_base64 content""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [ + provider_message.Message( + role='user', + content=[ + {'type': 'text', 'text': 'What is in this image?'}, + {'type': 'image_base64', 'image_base64': 'data:image/png;base64,abc123'}, + ], + ) + ] + result = requester._convert_messages(messages) + + assert len(result) == 1 + content = result[0]['content'] + assert isinstance(content, list) + # Check image_base64 converted to image_url + image_part = [p for p in content if p.get('type') == 'image_url'][0] + assert 'image_url' in image_part + assert image_part['image_url']['url'] == 'data:image/png;base64,abc123' + + def test_convert_message_with_multiple_text_parts(self): + """Test converting message with multiple text parts (LiteLLM handles this)""" + requester = litellmchat.LiteLLMRequester(ap=Mock(), config={}) + + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + messages = [ + provider_message.Message( + role='user', + content=[ + {'type': 'text', 'text': 'Hello'}, + {'type': 'text', 'text': 'World'}, + ], + ) + ] + result = requester._convert_messages(messages) + + assert len(result) == 1 + # LiteLLM handles multiple text parts, we pass them through + assert isinstance(result[0]['content'], list) + + +class TestScanModels: + """Test scan_models method""" + + @pytest.mark.asyncio + async def test_scan_models_basic(self): + """Test basic model scanning""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + }, + ) + + # Mock httpx response + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + {'id': 'gpt-4o'}, + {'id': 'text-embedding-3-small'}, + {'id': 'gpt-3.5-turbo'}, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + assert 'models' in result + assert len(result['models']) == 3 + # Check LLM models are first + assert result['models'][0]['type'] == 'llm' + # Check embedding model is detected + embedding_models = [m for m in result['models'] if m['type'] == 'embedding'] + assert len(embedding_models) == 1 + + @pytest.mark.asyncio + async def test_scan_models_enriches_llm_abilities_and_context_length(self): + """Scanned LLM models get LiteLLM-derived abilities and context length.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.openai.com/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(side_effect=lambda model_id: model_id == 'gpt-4o') + requester._supports_vision = Mock(side_effect=lambda model_id: model_id == 'gpt-4o') + requester._safe_context_length = Mock(side_effect=lambda model_id: 128000 if model_id == 'gpt-4o' else None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + {'id': 'gpt-4o'}, + {'id': 'text-embedding-3-small'}, + {'id': 'bge-reranker-v2'}, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + by_id = {model['id']: model for model in result['models']} + assert by_id['gpt-4o']['abilities'] == ['func_call', 'vision'] + assert by_id['gpt-4o']['context_length'] == 128000 + assert by_id['text-embedding-3-small']['type'] == 'embedding' + assert by_id['bge-reranker-v2']['type'] == 'rerank' + + @pytest.mark.asyncio + async def test_scan_models_prefers_context_length_from_provider_payload(self): + """Provider-supplied context_length is preserved before LiteLLM metadata fallback.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(return_value=False) + requester._supports_vision = Mock(return_value=False) + requester._safe_context_length = Mock(return_value=None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + {'id': 'moonshot-v1-128k', 'context_length': 131072}, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + assert result['models'][0]['context_length'] == 131072 + requester._safe_context_length.assert_not_called() + + def test_safe_context_length_tries_moonshot_metadata_alias(self): + """OpenAI-compatible Moonshot endpoints still use Moonshot metadata for context windows.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'custom_llm_provider': 'openai', + }, + ) + + with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens: + mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None + + assert requester._safe_context_length('moonshot-v1-128k') == 131072 + + def test_litellm_bool_helper_tries_moonshot_metadata_alias(self): + """OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'custom_llm_provider': 'openai', + }, + ) + + with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling: + mock_supports_function_calling.side_effect = ( + lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' + and custom_llm_provider is None + ) + + assert requester._supports_function_calling('kimi-k2.6') is True + + @pytest.mark.asyncio + async def test_scan_models_uses_provider_payload_for_vision_ability(self): + """Provider-supplied vision support is used when scanning models.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(return_value=False) + requester._supports_vision = Mock(return_value=False) + requester._safe_context_length = Mock(return_value=None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + { + 'id': 'moonshot-v1-128k-vision-preview', + 'supports_image_in': True, + }, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + assert result['models'][0]['abilities'] == ['vision'] + + def test_safe_context_length_falls_back_for_deepseek_v4_models(self): + """DeepSeek V4 API ids have a known 1M context even before LiteLLM maps them.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.deepseek.com', + 'custom_llm_provider': 'deepseek', + }, + ) + + with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')): + assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000 + assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000 + + @pytest.mark.asyncio + async def test_scan_models_no_base_url(self): + """Test scan_models without base_url raises error""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': '', + }, + ) + + with pytest.raises(errors.RequesterError) as exc_info: + await requester.scan_models() + + assert 'Base URL required' in str(exc_info.value) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/provider/test_localagent_sandbox_exec.py b/tests/unit_tests/provider/test_localagent_sandbox_exec.py index daa4eb2d..08b4c540 100644 --- a/tests/unit_tests/provider/test_localagent_sandbox_exec.py +++ b/tests/unit_tests/provider/test_localagent_sandbox_exec.py @@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message import langbot_plugin.api.entities.builtin.provider.session as provider_session -from langbot.pkg.provider.runners.localagent import LocalAgentRunner +from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator class RecordingProvider: @@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query: ) +def test_stream_accumulator_merges_fragmented_tool_call_arguments(): + accumulator = _StreamAccumulator(msg_sequence=1) + + assert ( + accumulator.add( + provider_message.MessageChunk( + role='assistant', + tool_calls=[ + provider_message.ToolCall( + id='call-1', + type='function', + function=provider_message.FunctionCall(name='exec', arguments='{"command":'), + ) + ], + ) + ) + is None + ) + + emitted = accumulator.add( + provider_message.MessageChunk( + role='assistant', + tool_calls=[ + provider_message.ToolCall( + id='call-1', + type='function', + function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'), + ) + ], + is_final=True, + ) + ) + + assert emitted is not None + final_msg = accumulator.final_message() + assert final_msg.tool_calls[0].function.name == 'exec' + assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}' + + @pytest.mark.asyncio async def test_localagent_uses_exec_for_exact_calculation(): provider = RecordingProvider() diff --git a/tests/unit_tests/provider/test_model_manager.py b/tests/unit_tests/provider/test_model_manager.py index b38a5d02..b6e82d3f 100644 --- a/tests/unit_tests/provider/test_model_manager.py +++ b/tests/unit_tests/provider/test_model_manager.py @@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg 'api_keys': ['temp-key'], }, 'abilities': ['func_call'], + 'context_length': 128000, 'extra_args': {'temperature': 0.5}, } @@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg assert runtime_model.model_entity.uuid == 'temp-model-uuid' assert runtime_model.model_entity.name == 'TempModel' + assert runtime_model.model_entity.context_length == 128000 + assert runtime_model.model_entity.extra_args == {'temperature': 0.5} + assert 'context_length' not in runtime_model.model_entity.extra_args assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid' assert runtime_model.provider.token_mgr.tokens == ['temp-key'] @@ -785,4 +789,4 @@ def test_provider_not_found_error_str(): error = provider_errors.ProviderNotFoundError('test-provider') assert str(error) == 'Provider test-provider not found' - assert error.provider_name == 'test-provider' \ No newline at end of file + assert error.provider_name == 'test-provider' diff --git a/tests/unit_tests/provider/test_model_service.py b/tests/unit_tests/provider/test_model_service.py index 60ac658e..ba1657cd 100644 --- a/tests/unit_tests/provider/test_model_service.py +++ b/tests/unit_tests/provider/test_model_service.py @@ -16,8 +16,6 @@ from langbot.pkg.entity.persistence import model as persistence_model from langbot.pkg.pipeline.preproc.preproc import PreProcessor from langbot.pkg.provider.modelmgr import requester from langbot.pkg.provider.modelmgr.modelmgr import ModelManager -from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions -from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions from langbot.pkg.provider.modelmgr.token import TokenManager from langbot.pkg.provider.runners.localagent import LocalAgentRunner @@ -90,74 +88,6 @@ def test_token_manager_next_token_ignores_empty_token_list(): assert token_mgr.using_token_index == 0 -@pytest.mark.asyncio -async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch): - captured_kwargs = {} - - def fake_client(**kwargs): - captured_kwargs.update(kwargs) - return SimpleNamespace(**kwargs) - - monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.openai.AsyncClient', fake_client) - monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.httpx.AsyncClient', fake_client) - - requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={}) - await requester_inst.initialize() - - assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key - - -@pytest.mark.asyncio -async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch): - captured_kwargs = {} - - def fake_client(**kwargs): - captured_kwargs.update(kwargs) - return SimpleNamespace(**kwargs) - - monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client) - monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client) - - requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={}) - await requester_inst.initialize() - - assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key - - -@pytest.mark.asyncio -async def test_openai_embedding_call_overrides_placeholder_api_key(): - captured_request = {} - - async def fake_create(**kwargs): - captured_request['api_key'] = fake_client.api_key - captured_request['kwargs'] = kwargs - return SimpleNamespace( - data=[SimpleNamespace(embedding=[0.1, 0.2])], - usage=SimpleNamespace(prompt_tokens=3, total_tokens=3), - ) - - fake_client = SimpleNamespace( - api_key=OpenAIChatCompletions.init_api_key, - embeddings=SimpleNamespace(create=fake_create), - ) - - requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={}) - requester_inst.client = fake_client - - embeddings, usage_info = await requester_inst.invoke_embedding( - model=requester.RuntimeEmbeddingModel( - model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}), - provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])), - ), - input_text=['hello'], - ) - - assert captured_request['api_key'] == 'runtime-key' - assert captured_request['kwargs']['model'] == 'text-embedding-3-small' - assert embeddings == [[0.1, 0.2]] - assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3} - - @pytest.mark.asyncio async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline(): from langbot.pkg.api.http.service.model import LLMModelsService diff --git a/tests/unit_tests/provider/test_tool_manager.py b/tests/unit_tests/provider/test_tool_manager.py index 8e8439f5..fbfcb13f 100644 --- a/tests/unit_tests/provider/test_tool_manager.py +++ b/tests/unit_tests/provider/test_tool_manager.py @@ -1,7 +1,7 @@ """Unit tests for ToolManager. Tests cover: -- Tool schema generation for OpenAI and Anthropic +- Tool schema generation for OpenAI/LiteLLM - Tool execution dispatch """ @@ -109,28 +109,6 @@ class TestToolManagerSchemaGeneration: assert tool2['type'] == 'function' assert tool2['function']['name'] == 'calculate' - @pytest.mark.asyncio - async def test_generate_tools_for_anthropic(self, mock_app, sample_tools): - """Test that generate_tools_for_anthropic produces correct schema.""" - toolmgr = get_toolmgr_module() - - manager = toolmgr.ToolManager(mock_app) - result = await manager.generate_tools_for_anthropic(sample_tools) - - assert len(result) == 2 - - # Verify first tool schema (Anthropic format) - tool1 = result[0] - assert tool1['name'] == 'get_weather' - assert tool1['description'] == 'Get current weather for a location' - assert 'input_schema' in tool1 - assert tool1['input_schema']['type'] == 'object' - - # Verify second tool schema - tool2 = result[1] - assert tool2['name'] == 'calculate' - assert 'input_schema' in tool2 - @pytest.mark.asyncio async def test_generate_tools_empty_list(self, mock_app): """Test that generating tools from empty list returns empty list.""" @@ -141,9 +119,6 @@ class TestToolManagerSchemaGeneration: openai_result = await manager.generate_tools_for_openai([]) assert openai_result == [] - anthropic_result = await manager.generate_tools_for_anthropic([]) - assert anthropic_result == [] - @pytest.mark.asyncio async def test_openai_schema_fields_complete(self, mock_app, sample_tools): """Test that OpenAI schema includes all required fields.""" @@ -161,20 +136,6 @@ class TestToolManagerSchemaGeneration: assert 'description' in func assert 'parameters' in func - @pytest.mark.asyncio - async def test_anthropic_schema_fields_complete(self, mock_app, sample_tools): - """Test that Anthropic schema includes all required fields.""" - toolmgr = get_toolmgr_module() - - manager = toolmgr.ToolManager(mock_app) - result = await manager.generate_tools_for_anthropic(sample_tools) - - for tool_schema in result: - assert 'name' in tool_schema - assert 'description' in tool_schema - assert 'input_schema' in tool_schema - - class TestToolManagerExecuteFuncCall: """Tests for execute_func_call method.""" diff --git a/uv.lock b/uv.lock index 9e5f8cfe..450328f5 100644 --- a/uv.lock +++ b/uv.lock @@ -1184,6 +1184,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/ee/aa015c5de8b0dc42a8e507eae8c2de5d1c0e068c896858fec6d502402ed6/ebooklib-0.20-py3-none-any.whl", hash = "sha256:fff5322517a37e31c972d27be7d982cc3928c16b3dcc5fd7e8f7c0f5d7bcf42b", size = 40995, upload-time = "2025-10-26T20:56:19.104Z" }, ] +[[package]] +name = "fastuuid" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/7d/d9daedf0f2ebcacd20d599928f8913e9d2aea1d56d2d355a93bfa2b611d7/fastuuid-0.14.0.tar.gz", hash = "sha256:178947fc2f995b38497a74172adee64fdeb8b7ec18f2a5934d037641ba265d26", size = 18232, upload-time = "2025-10-19T22:19:22.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/f3/12481bda4e5b6d3e698fbf525df4443cc7dce746f246b86b6fcb2fba1844/fastuuid-0.14.0-cp311-cp311-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:73946cb950c8caf65127d4e9a325e2b6be0442a224fd51ba3b6ac44e1912ce34", size = 516386, upload-time = "2025-10-19T22:42:40.176Z" }, + { url = "https://files.pythonhosted.org/packages/59/19/2fc58a1446e4d72b655648eb0879b04e88ed6fa70d474efcf550f640f6ec/fastuuid-0.14.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:12ac85024637586a5b69645e7ed986f7535106ed3013640a393a03e461740cb7", size = 264569, upload-time = "2025-10-19T22:25:50.977Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/3c74756e5b02c40cfcc8b1d8b5bac4edbd532b55917a6bcc9113550e99d1/fastuuid-0.14.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:05a8dde1f395e0c9b4be515b7a521403d1e8349443e7641761af07c7ad1624b1", size = 254366, upload-time = "2025-10-19T22:29:49.166Z" }, + { url = "https://files.pythonhosted.org/packages/52/96/d761da3fccfa84f0f353ce6e3eb8b7f76b3aa21fd25e1b00a19f9c80a063/fastuuid-0.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09378a05020e3e4883dfdab438926f31fea15fd17604908f3d39cbeb22a0b4dc", size = 278978, upload-time = "2025-10-19T22:35:41.306Z" }, + { url = "https://files.pythonhosted.org/packages/fc/c2/f84c90167cc7765cb82b3ff7808057608b21c14a38531845d933a4637307/fastuuid-0.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbb0c4b15d66b435d2538f3827f05e44e2baafcc003dd7d8472dc67807ab8fd8", size = 279692, upload-time = "2025-10-19T22:25:36.997Z" }, + { url = "https://files.pythonhosted.org/packages/af/7b/4bacd03897b88c12348e7bd77943bac32ccf80ff98100598fcff74f75f2e/fastuuid-0.14.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cd5a7f648d4365b41dbf0e38fe8da4884e57bed4e77c83598e076ac0c93995e7", size = 303384, upload-time = "2025-10-19T22:29:46.578Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a2/584f2c29641df8bd810d00c1f21d408c12e9ad0c0dafdb8b7b29e5ddf787/fastuuid-0.14.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c0a94245afae4d7af8c43b3159d5e3934c53f47140be0be624b96acd672ceb73", size = 460921, upload-time = "2025-10-19T22:36:42.006Z" }, + { url = "https://files.pythonhosted.org/packages/24/68/c6b77443bb7764c760e211002c8638c0c7cce11cb584927e723215ba1398/fastuuid-0.14.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:2b29e23c97e77c3a9514d70ce343571e469098ac7f5a269320a0f0b3e193ab36", size = 480575, upload-time = "2025-10-19T22:28:18.975Z" }, + { url = "https://files.pythonhosted.org/packages/5a/87/93f553111b33f9bb83145be12868c3c475bf8ea87c107063d01377cc0e8e/fastuuid-0.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1e690d48f923c253f28151b3a6b4e335f2b06bf669c68a02665bc150b7839e94", size = 452317, upload-time = "2025-10-19T22:25:32.75Z" }, + { url = "https://files.pythonhosted.org/packages/9e/8c/a04d486ca55b5abb7eaa65b39df8d891b7b1635b22db2163734dc273579a/fastuuid-0.14.0-cp311-cp311-win32.whl", hash = "sha256:a6f46790d59ab38c6aa0e35c681c0484b50dc0acf9e2679c005d61e019313c24", size = 154804, upload-time = "2025-10-19T22:24:15.615Z" }, + { url = "https://files.pythonhosted.org/packages/9c/b2/2d40bf00820de94b9280366a122cbaa60090c8cf59e89ac3938cf5d75895/fastuuid-0.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:e150eab56c95dc9e3fefc234a0eedb342fac433dacc273cd4d150a5b0871e1fa", size = 156099, upload-time = "2025-10-19T22:24:31.646Z" }, + { url = "https://files.pythonhosted.org/packages/02/a2/e78fcc5df65467f0d207661b7ef86c5b7ac62eea337c0c0fcedbeee6fb13/fastuuid-0.14.0-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77e94728324b63660ebf8adb27055e92d2e4611645bf12ed9d88d30486471d0a", size = 510164, upload-time = "2025-10-19T22:31:45.635Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b3/c846f933f22f581f558ee63f81f29fa924acd971ce903dab1a9b6701816e/fastuuid-0.14.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:caa1f14d2102cb8d353096bc6ef6c13b2c81f347e6ab9d6fbd48b9dea41c153d", size = 261837, upload-time = "2025-10-19T22:38:38.53Z" }, + { url = "https://files.pythonhosted.org/packages/54/ea/682551030f8c4fa9a769d9825570ad28c0c71e30cf34020b85c1f7ee7382/fastuuid-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d23ef06f9e67163be38cece704170486715b177f6baae338110983f99a72c070", size = 251370, upload-time = "2025-10-19T22:40:26.07Z" }, + { url = "https://files.pythonhosted.org/packages/14/dd/5927f0a523d8e6a76b70968e6004966ee7df30322f5fc9b6cdfb0276646a/fastuuid-0.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c9ec605ace243b6dbe3bd27ebdd5d33b00d8d1d3f580b39fdd15cd96fd71796", size = 277766, upload-time = "2025-10-19T22:37:23.779Z" }, + { url = "https://files.pythonhosted.org/packages/16/6e/c0fb547eef61293153348f12e0f75a06abb322664b34a1573a7760501336/fastuuid-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:808527f2407f58a76c916d6aa15d58692a4a019fdf8d4c32ac7ff303b7d7af09", size = 278105, upload-time = "2025-10-19T22:26:56.821Z" }, + { url = "https://files.pythonhosted.org/packages/2d/b1/b9c75e03b768f61cf2e84ee193dc18601aeaf89a4684b20f2f0e9f52b62c/fastuuid-0.14.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2fb3c0d7fef6674bbeacdd6dbd386924a7b60b26de849266d1ff6602937675c8", size = 301564, upload-time = "2025-10-19T22:30:31.604Z" }, + { url = "https://files.pythonhosted.org/packages/fc/fa/f7395fdac07c7a54f18f801744573707321ca0cee082e638e36452355a9d/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab3f5d36e4393e628a4df337c2c039069344db5f4b9d2a3c9cea48284f1dd741", size = 459659, upload-time = "2025-10-19T22:31:32.341Z" }, + { url = "https://files.pythonhosted.org/packages/66/49/c9fd06a4a0b1f0f048aacb6599e7d96e5d6bc6fa680ed0d46bf111929d1b/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b9a0ca4f03b7e0b01425281ffd44e99d360e15c895f1907ca105854ed85e2057", size = 478430, upload-time = "2025-10-19T22:26:22.962Z" }, + { url = "https://files.pythonhosted.org/packages/be/9c/909e8c95b494e8e140e8be6165d5fc3f61fdc46198c1554df7b3e1764471/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3acdf655684cc09e60fb7e4cf524e8f42ea760031945aa8086c7eae2eeeabeb8", size = 450894, upload-time = "2025-10-19T22:27:01.647Z" }, + { url = "https://files.pythonhosted.org/packages/90/eb/d29d17521976e673c55ef7f210d4cdd72091a9ec6755d0fd4710d9b3c871/fastuuid-0.14.0-cp312-cp312-win32.whl", hash = "sha256:9579618be6280700ae36ac42c3efd157049fe4dd40ca49b021280481c78c3176", size = 154374, upload-time = "2025-10-19T22:29:19.879Z" }, + { url = "https://files.pythonhosted.org/packages/cc/fc/f5c799a6ea6d877faec0472d0b27c079b47c86b1cdc577720a5386483b36/fastuuid-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:d9e4332dc4ba054434a9594cbfaf7823b57993d7d8e7267831c3e059857cf397", size = 156550, upload-time = "2025-10-19T22:27:49.658Z" }, + { url = "https://files.pythonhosted.org/packages/a5/83/ae12dd39b9a39b55d7f90abb8971f1a5f3c321fd72d5aa83f90dc67fe9ed/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77a09cb7427e7af74c594e409f7731a0cf887221de2f698e1ca0ebf0f3139021", size = 510720, upload-time = "2025-10-19T22:42:34.633Z" }, + { url = "https://files.pythonhosted.org/packages/53/b0/a4b03ff5d00f563cc7546b933c28cb3f2a07344b2aec5834e874f7d44143/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9bd57289daf7b153bfa3e8013446aa144ce5e8c825e9e366d455155ede5ea2dc", size = 262024, upload-time = "2025-10-19T22:30:25.482Z" }, + { url = "https://files.pythonhosted.org/packages/9c/6d/64aee0a0f6a58eeabadd582e55d0d7d70258ffdd01d093b30c53d668303b/fastuuid-0.14.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac60fc860cdf3c3f327374db87ab8e064c86566ca8c49d2e30df15eda1b0c2d5", size = 251679, upload-time = "2025-10-19T22:36:14.096Z" }, + { url = "https://files.pythonhosted.org/packages/60/f5/a7e9cda8369e4f7919d36552db9b2ae21db7915083bc6336f1b0082c8b2e/fastuuid-0.14.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab32f74bd56565b186f036e33129da77db8be09178cd2f5206a5d4035fb2a23f", size = 277862, upload-time = "2025-10-19T22:36:23.302Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d3/8ce11827c783affffd5bd4d6378b28eb6cc6d2ddf41474006b8d62e7448e/fastuuid-0.14.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e678459cf4addaedd9936bbb038e35b3f6b2061330fd8f2f6a1d80414c0f87", size = 278278, upload-time = "2025-10-19T22:29:43.809Z" }, + { url = "https://files.pythonhosted.org/packages/a2/51/680fb6352d0bbade04036da46264a8001f74b7484e2fd1f4da9e3db1c666/fastuuid-0.14.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1e3cc56742f76cd25ecb98e4b82a25f978ccffba02e4bdce8aba857b6d85d87b", size = 301788, upload-time = "2025-10-19T22:36:06.825Z" }, + { url = "https://files.pythonhosted.org/packages/fa/7c/2014b5785bd8ebdab04ec857635ebd84d5ee4950186a577db9eff0fb8ff6/fastuuid-0.14.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:cb9a030f609194b679e1660f7e32733b7a0f332d519c5d5a6a0a580991290022", size = 459819, upload-time = "2025-10-19T22:35:31.623Z" }, + { url = "https://files.pythonhosted.org/packages/01/d2/524d4ceeba9160e7a9bc2ea3e8f4ccf1ad78f3bde34090ca0c51f09a5e91/fastuuid-0.14.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:09098762aad4f8da3a888eb9ae01c84430c907a297b97166b8abc07b640f2995", size = 478546, upload-time = "2025-10-19T22:26:03.023Z" }, + { url = "https://files.pythonhosted.org/packages/bc/17/354d04951ce114bf4afc78e27a18cfbd6ee319ab1829c2d5fb5e94063ac6/fastuuid-0.14.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1383fff584fa249b16329a059c68ad45d030d5a4b70fb7c73a08d98fd53bcdab", size = 450921, upload-time = "2025-10-19T22:31:02.151Z" }, + { url = "https://files.pythonhosted.org/packages/fb/be/d7be8670151d16d88f15bb121c5b66cdb5ea6a0c2a362d0dcf30276ade53/fastuuid-0.14.0-cp313-cp313-win32.whl", hash = "sha256:a0809f8cc5731c066c909047f9a314d5f536c871a7a22e815cc4967c110ac9ad", size = 154559, upload-time = "2025-10-19T22:36:36.011Z" }, + { url = "https://files.pythonhosted.org/packages/22/1d/5573ef3624ceb7abf4a46073d3554e37191c868abc3aecd5289a72f9810a/fastuuid-0.14.0-cp313-cp313-win_amd64.whl", hash = "sha256:0df14e92e7ad3276327631c9e7cec09e32572ce82089c55cb1bb8df71cf394ed", size = 156539, upload-time = "2025-10-19T22:33:35.898Z" }, + { url = "https://files.pythonhosted.org/packages/16/c9/8c7660d1fe3862e3f8acabd9be7fc9ad71eb270f1c65cce9a2b7a31329ab/fastuuid-0.14.0-cp314-cp314-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:b852a870a61cfc26c884af205d502881a2e59cc07076b60ab4a951cc0c94d1ad", size = 510600, upload-time = "2025-10-19T22:43:44.17Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f4/a989c82f9a90d0ad995aa957b3e572ebef163c5299823b4027986f133dfb/fastuuid-0.14.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:c7502d6f54cd08024c3ea9b3514e2d6f190feb2f46e6dbcd3747882264bb5f7b", size = 262069, upload-time = "2025-10-19T22:43:38.38Z" }, + { url = "https://files.pythonhosted.org/packages/da/6c/a1a24f73574ac995482b1326cf7ab41301af0fabaa3e37eeb6b3df00e6e2/fastuuid-0.14.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1ca61b592120cf314cfd66e662a5b54a578c5a15b26305e1b8b618a6f22df714", size = 251543, upload-time = "2025-10-19T22:32:22.537Z" }, + { url = "https://files.pythonhosted.org/packages/1a/20/2a9b59185ba7a6c7b37808431477c2d739fcbdabbf63e00243e37bd6bf49/fastuuid-0.14.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa75b6657ec129d0abded3bec745e6f7ab642e6dba3a5272a68247e85f5f316f", size = 277798, upload-time = "2025-10-19T22:33:53.821Z" }, + { url = "https://files.pythonhosted.org/packages/ef/33/4105ca574f6ded0af6a797d39add041bcfb468a1255fbbe82fcb6f592da2/fastuuid-0.14.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8a0dfea3972200f72d4c7df02c8ac70bad1bb4c58d7e0ec1e6f341679073a7f", size = 278283, upload-time = "2025-10-19T22:29:02.812Z" }, + { url = "https://files.pythonhosted.org/packages/fe/8c/fca59f8e21c4deb013f574eae05723737ddb1d2937ce87cb2a5d20992dc3/fastuuid-0.14.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1bf539a7a95f35b419f9ad105d5a8a35036df35fdafae48fb2fd2e5f318f0d75", size = 301627, upload-time = "2025-10-19T22:35:54.985Z" }, + { url = "https://files.pythonhosted.org/packages/cb/e2/f78c271b909c034d429218f2798ca4e89eeda7983f4257d7865976ddbb6c/fastuuid-0.14.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:9a133bf9cc78fdbd1179cb58a59ad0100aa32d8675508150f3658814aeefeaa4", size = 459778, upload-time = "2025-10-19T22:28:00.999Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f0/5ff209d865897667a2ff3e7a572267a9ced8f7313919f6d6043aed8b1caa/fastuuid-0.14.0-cp314-cp314-musllinux_1_1_i686.whl", hash = "sha256:f54d5b36c56a2d5e1a31e73b950b28a0d83eb0c37b91d10408875a5a29494bad", size = 478605, upload-time = "2025-10-19T22:36:21.764Z" }, + { url = "https://files.pythonhosted.org/packages/e0/c8/2ce1c78f983a2c4987ea865d9516dbdfb141a120fd3abb977ae6f02ba7ca/fastuuid-0.14.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:ec27778c6ca3393ef662e2762dba8af13f4ec1aaa32d08d77f71f2a70ae9feb8", size = 450837, upload-time = "2025-10-19T22:34:37.178Z" }, + { url = "https://files.pythonhosted.org/packages/df/60/dad662ec9a33b4a5fe44f60699258da64172c39bd041da2994422cdc40fe/fastuuid-0.14.0-cp314-cp314-win32.whl", hash = "sha256:e23fc6a83f112de4be0cc1990e5b127c27663ae43f866353166f87df58e73d06", size = 154532, upload-time = "2025-10-19T22:35:18.217Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f6/da4db31001e854025ffd26bc9ba0740a9cbba2c3259695f7c5834908b336/fastuuid-0.14.0-cp314-cp314-win_amd64.whl", hash = "sha256:df61342889d0f5e7a32f7284e55ef95103f2110fee433c2ae7c2c0956d76ac8a", size = 156457, upload-time = "2025-10-19T22:33:44.579Z" }, +] + [[package]] name = "filelock" version = "3.20.3" @@ -1949,6 +2001,7 @@ dependencies = [ { name = "langsmith" }, { name = "lark-oapi" }, { name = "line-bot-sdk" }, + { name = "litellm" }, { name = "mako" }, { name = "markdown" }, { name = "matrix-nio" }, @@ -2036,6 +2089,7 @@ requires-dist = [ { name = "langsmith", specifier = ">=0.8.0" }, { name = "lark-oapi", specifier = ">=1.5.5" }, { name = "line-bot-sdk", specifier = ">=3.19.0" }, + { name = "litellm", specifier = ">=1.0.0" }, { name = "mako", specifier = ">=1.3.12" }, { name = "markdown", specifier = ">=3.6" }, { name = "matrix-nio", specifier = ">=0.25.2" }, @@ -2363,6 +2417,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/1e/b832de447dee8b582cac175871d2f6c3d5077cc56d5575cadba1fd1cccfa/linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79", size = 19820, upload-time = "2024-02-04T14:48:02.496Z" }, ] +[[package]] +name = "litellm" +version = "1.88.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "click" }, + { name = "fastuuid" }, + { name = "httpx" }, + { name = "importlib-metadata" }, + { name = "jinja2" }, + { name = "jsonschema" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "tiktoken" }, + { name = "tokenizers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/16/ea/f99ececb7f22703fe120f1d8be9ffb749ec9453fbbbbbebc0d6a6b4d7864/litellm-1.88.1.tar.gz", hash = "sha256:89c6b74cc7912d6365793006ff951c0450fe847625008dfe49de8a7dc4529aa5", size = 13885969, upload-time = "2026-06-09T01:06:25.192Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/9a/8f8909201b4bebaf96498c09226f6baa8540086a4c4188ad57d7dfbd97c1/litellm-1.88.1-py3-none-any.whl", hash = "sha256:369b84e57d9426582ddc35e731956ddb6618cda97cc44e4e4d2dfa75982a6e3a", size = 15276206, upload-time = "2026-06-09T01:06:16.72Z" }, +] + [[package]] name = "logbook" version = "1.9.2" @@ -3302,7 +3379,7 @@ wheels = [ [[package]] name = "openai" -version = "2.16.0" +version = "2.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -3314,9 +3391,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b1/6c/e4c964fcf1d527fdf4739e7cc940c60075a4114d50d03871d5d5b1e13a88/openai-2.16.0.tar.gz", hash = "sha256:42eaa22ca0d8ded4367a77374104d7a2feafee5bd60a107c3c11b5243a11cd12", size = 629649, upload-time = "2026-01-27T23:28:02.579Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/36/4c926a91554483977608951360c18c2e911592785eb87a6437813f6123f7/openai-2.41.1.tar.gz", hash = "sha256:23d617a0432457ad844973bee8f540be9da90894f7c5686852d2d365da058f57", size = 783584, upload-time = "2026-06-10T16:10:37.667Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/83/0315bf2cfd75a2ce8a7e54188e9456c60cec6c0cf66728ed07bd9859ff26/openai-2.16.0-py3-none-any.whl", hash = "sha256:5f46643a8f42899a84e80c38838135d7038e7718333ce61396994f887b09a59b", size = 1068612, upload-time = "2026-01-27T23:28:00.356Z" }, + { url = "https://files.pythonhosted.org/packages/20/74/925d7b3892927e9804aaf58d374a45dc28e4420ff90e992272b77286343e/openai-2.41.1-py3-none-any.whl", hash = "sha256:a939565f350cb7443cb843b801b88c716ac8024b492fb94ca269d5f6b1bbefd6", size = 1353380, upload-time = "2026-06-10T16:10:35.756Z" }, ] [[package]] diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index 16c6663d..ccb03b3c 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -64,6 +64,17 @@ function convertExtraArgsToObject( return obj; } +function parseContextLength( + value: number | null | undefined, + invalidMessage: string, +): number | null { + if (value === undefined || value === null) return null; + if (!Number.isInteger(value) || value <= 0) { + throw new Error(invalidMessage); + } + return value; +} + export default function ModelsDialog({ open, onOpenChange, @@ -91,6 +102,12 @@ export default function ModelsDialog({ null, ); + // Map of requester name -> support_type[] (from requester manifests), + // used to restrict which model-type tabs are shown when adding models. + const [requesterSupportTypes, setRequesterSupportTypes] = useState< + Record + >({}); + // Popover states const [addModelPopoverOpen, setAddModelPopoverOpen] = useState( null, @@ -122,6 +139,7 @@ export default function ModelsDialog({ if (open) { loadUserInfo(); loadProviders(); + loadRequesterSupportTypes(); } }, [open]); @@ -161,6 +179,19 @@ export default function ModelsDialog({ } } + async function loadRequesterSupportTypes() { + try { + const resp = await httpClient.getProviderRequesters(); + const map: Record = {}; + for (const r of resp.requesters) { + map[r.name] = r.spec?.support_type ?? []; + } + setRequesterSupportTypes(map); + } catch (err) { + console.error('Failed to load requester support types', err); + } + } + async function loadProviderModels(providerUuid: string, silent = false) { if (loadingProviders.has(providerUuid)) return; @@ -254,6 +285,7 @@ export default function ModelsDialog({ name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) { if (!name.trim()) { toast.error(t('models.modelNameRequired')); @@ -268,6 +300,10 @@ export default function ModelsDialog({ name, provider_uuid: providerUuid, abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), extra_args: extraArgsObj, } as never); } else if (modelType === 'embedding') { @@ -325,6 +361,7 @@ export default function ModelsDialog({ name: item.model.name, provider_uuid: providerUuid, abilities: item.abilities, + context_length: item.model.context_length ?? null, extra_args: {}, } as never); } else if (effectiveType === 'embedding') { @@ -361,6 +398,7 @@ export default function ModelsDialog({ name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) { if (!name.trim()) { toast.error(t('models.modelNameRequired')); @@ -375,6 +413,10 @@ export default function ModelsDialog({ name, provider_uuid: providerUuid, abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), extra_args: extraArgsObj, } as never); } else if (modelType === 'embedding') { @@ -495,6 +537,7 @@ export default function ModelsDialog({ key={provider.uuid} provider={provider} isLangBotModels={isLangBotModels} + supportTypes={requesterSupportTypes[provider.requester]} isExpanded={expandedProviders.has(provider.uuid)} isLoading={loadingProviders.has(provider.uuid)} models={providerModels[provider.uuid]} @@ -509,8 +552,15 @@ export default function ModelsDialog({ onSpaceLogin={handleSpaceLogin} onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)} onCloseAddModel={() => setAddModelPopoverOpen(null)} - onAddModel={(modelType, name, abilities, extraArgs) => - handleAddModel(provider.uuid, modelType, name, abilities, extraArgs) + onAddModel={(modelType, name, abilities, extraArgs, contextLength) => + handleAddModel( + provider.uuid, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) } onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)} onAddScannedModels={(modelType, models) => @@ -518,7 +568,14 @@ export default function ModelsDialog({ } onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)} onCloseEditModel={() => setEditModelPopoverOpen(null)} - onUpdateModel={(modelId, modelType, name, abilities, extraArgs) => + onUpdateModel={( + modelId, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) => handleUpdateModel( provider.uuid, modelId, @@ -526,6 +583,7 @@ export default function ModelsDialog({ name, abilities, extraArgs, + contextLength, ) } onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)} diff --git a/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx index c596037a..ae20bb25 100644 --- a/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx +++ b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState } from 'react'; +import { useEffect, useState, useRef, useCallback } from 'react'; import { httpClient } from '@/app/infra/http/HttpClient'; import { zodResolver } from '@hookform/resolvers/zod'; @@ -16,19 +16,12 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; -import { - Select, - SelectContent, - SelectGroup, - SelectItem, - SelectLabel, - SelectTrigger, - SelectValue, -} from '@/components/ui/select'; import { DialogFooter } from '@/components/ui/dialog'; import { toast } from 'sonner'; import { extractI18nObject } from '@/i18n/I18nProvider'; import { CustomApiError } from '@/app/infra/entities/common'; +import { cn } from '@/lib/utils'; +import { Check, ChevronDown, Search } from 'lucide-react'; const getFormSchema = (t: (key: string) => string) => z.object({ @@ -61,6 +54,7 @@ export default function ProviderForm({ api_key: '', }, }); + const { setValue } = form; const [requesterList, setRequesterList] = useState< { @@ -69,20 +63,15 @@ export default function ProviderForm({ category: string; defaultUrl: string; description: string; + alias: string; }[] >([]); + const [searchQuery, setSearchQuery] = useState(''); + const [isOpen, setIsOpen] = useState(false); + const dropdownRef = useRef(null); + const searchInputRef = useRef(null); - useEffect(() => { - async function init() { - await loadRequesters(); - if (providerId) { - await loadProvider(providerId); - } - } - init(); - }, [providerId]); - - async function loadRequesters() { + const loadRequesters = useCallback(async () => { const resp = await httpClient.getProviderRequesters(); setRequesterList( resp.requesters @@ -96,19 +85,82 @@ export default function ProviderForm({ .find((c) => c.name === 'base_url') ?.default?.toString() || '', description: extractI18nObject(item.description), + alias: item.spec.alias || '', })), ); - } + }, []); - async function loadProvider(id: string) { - const resp = await httpClient.getModelProvider(id); - const provider = resp.provider; + const loadProvider = useCallback( + async (id: string) => { + const resp = await httpClient.getModelProvider(id); + const provider = resp.provider; - form.setValue('name', provider.name); - form.setValue('requester', provider.requester); - form.setValue('base_url', provider.base_url); - form.setValue('api_key', provider.api_keys?.[0] || ''); - } + setValue('name', provider.name); + setValue('requester', provider.requester); + setValue('base_url', provider.base_url); + setValue('api_key', provider.api_keys?.[0] || ''); + }, + [setValue], + ); + + useEffect(() => { + async function init() { + await loadRequesters(); + if (providerId) { + await loadProvider(providerId); + } + } + init(); + }, [providerId, loadProvider, loadRequesters]); + + // Close dropdown when clicking outside + useEffect(() => { + function handleClickOutside(event: MouseEvent) { + if ( + dropdownRef.current && + !dropdownRef.current.contains(event.target as Node) + ) { + setIsOpen(false); + setSearchQuery(''); + } + } + document.addEventListener('mousedown', handleClickOutside); + return () => document.removeEventListener('mousedown', handleClickOutside); + }, []); + + // Focus search input when dropdown opens + useEffect(() => { + if (isOpen && searchInputRef.current) { + searchInputRef.current.focus(); + } + }, [isOpen]); + + // Filter requesters based on search query + const filteredRequesters = requesterList.filter( + (r) => + r.label.toLowerCase().includes(searchQuery.toLowerCase()) || + r.value.toLowerCase().includes(searchQuery.toLowerCase()) || + r.alias.toLowerCase().includes(searchQuery.toLowerCase()), + ); + + // Group filtered requesters by category + const groupedRequesters = { + builtin: filteredRequesters.filter((r) => r.category === 'builtin'), + manufacturer: filteredRequesters.filter( + (r) => r.category === 'manufacturer', + ), + maas: filteredRequesters.filter((r) => r.category === 'maas'), + 'self-hosted': filteredRequesters.filter( + (r) => r.category === 'self-hosted', + ), + }; + + const categoryLabels: Record = { + builtin: t('models.builtin'), + manufacturer: t('models.modelManufacturer'), + maas: t('models.aggregationPlatform'), + 'self-hosted': t('models.selfDeployed'), + }; async function handleFormSubmit(values: z.infer) { const data = { @@ -168,17 +220,16 @@ export default function ProviderForm({ {t('models.requester')} * - + + + + {/* Dropdown */} + {isOpen && ( +
+ {/* Search input */} +
+ + setSearchQuery(e.target.value)} + className="flex h-10 w-full rounded-md bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground" + /> +
+ + {/* Options list */} +
+ {Object.entries(groupedRequesters).map( + ([category, items]) => { + if (items.length === 0) return null; + return ( +
+
+ {categoryLabels[category]} +
+ {items.map((r) => ( + + ))} +
+ ); + }, + )} + {filteredRequesters.length === 0 && ( +
+ No results found. +
+ )} +
+
+ )} + {selectedRequester?.description && (

diff --git a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx index c0899318..ddfc3a70 100644 --- a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx +++ b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx @@ -34,6 +34,7 @@ interface AddModelPopoverProps { isOpen: boolean; initialMode?: 'manual' | 'scan'; trigger?: React.ReactNode; + supportTypes?: string[]; onOpen: () => void; onClose: () => void; onAddModel: ( @@ -41,6 +42,7 @@ interface AddModelPopoverProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onScanModels: (modelType?: ModelType) => Promise; onAddScannedModels: ( @@ -63,6 +65,7 @@ export default function AddModelPopover({ isOpen, initialMode = 'manual', trigger, + supportTypes, onOpen, onClose, onAddModel, @@ -77,10 +80,26 @@ export default function AddModelPopover({ const { t } = useTranslation(); const prevIsOpenRef = useRef(false); - const [tab, setTab] = useState('llm'); + // Map manifest support_type values to UI tab values. + // Manifest uses 'text-embedding'; the UI tab uses 'embedding'. + const tabSupport: Record = { + llm: 'llm', + embedding: 'text-embedding', + rerank: 'rerank', + }; + const allTabs: ModelType[] = ['llm', 'embedding', 'rerank']; + // When supportTypes is undefined (unknown requester), show all tabs for + // backward compatibility. Otherwise only show supported tabs. + const visibleTabs: ModelType[] = supportTypes + ? allTabs.filter((tabKey) => supportTypes.includes(tabSupport[tabKey])) + : allTabs; + const defaultTab: ModelType = visibleTabs[0] ?? 'llm'; + + const [tab, setTab] = useState(defaultTab); const [mode, setMode] = useState<'manual' | 'scan'>('manual'); const [name, setName] = useState(''); const [abilities, setAbilities] = useState([]); + const [contextLength, setContextLength] = useState(''); const [extraArgs, setExtraArgs] = useState([]); const [scanLoading, setScanLoading] = useState(false); const [scannedModels, setScannedModels] = useState( @@ -94,10 +113,11 @@ export default function AddModelPopover({ useEffect(() => { const wasOpen = prevIsOpenRef.current; if (isOpen && !wasOpen) { - setTab('llm'); + setTab(defaultTab); setMode(initialMode); setName(''); setAbilities([]); + setContextLength(''); setExtraArgs([]); setScanLoading(false); setScannedModels([]); @@ -119,7 +139,11 @@ export default function AddModelPopover({ }, [tab, mode]); const handleAdd = async () => { - await onAddModel(tab, name, abilities, extraArgs); + const parsedContextLength = + tab === 'llm' && contextLength.trim() + ? Number(contextLength.trim()) + : null; + await onAddModel(tab, name, abilities, extraArgs, parsedContextLength); }; const handleTest = async () => { @@ -130,32 +154,6 @@ export default function AddModelPopover({ setScanLoading(true); try { const result = await onScanModels(trigger ? undefined : tab); - - const debugData = ( - result.debug?.response as { data?: Record[] } - )?.data; - if (Array.isArray(debugData)) { - const debugMap = new Map>(); - for (const item of debugData) { - if (typeof item?.id === 'string') { - debugMap.set(item.id, item); - } - } - for (const model of result.models) { - const debugItem = debugMap.get(model.id); - if (!debugItem) continue; - const features = debugItem.features as - | Record - | undefined; - const tools = features?.tools as Record | undefined; - if (tools?.function_calling === true) { - const nextAbilities = new Set(model.abilities || []); - nextAbilities.add('func_call'); - model.abilities = [...nextAbilities]; - } - } - } - setScannedModels(result.models); setSelectedScannedModels({}); } finally { @@ -279,20 +277,31 @@ export default function AddModelPopover({ className="flex flex-col min-h-0 flex-1" >

- {!(trigger && initialMode === 'scan') && ( - - - - {t('models.chat')} - - - - {t('models.embedding')} - - - - {t('models.rerank')} - + {!(trigger && initialMode === 'scan') && visibleTabs.length > 1 && ( + + {visibleTabs.includes('llm') && ( + + + {t('models.chat')} + + )} + {visibleTabs.includes('embedding') && ( + + + {t('models.embedding')} + + )} + {visibleTabs.includes('rerank') && ( + + + {t('models.rerank')} + + )} )}
@@ -344,6 +353,24 @@ export default function AddModelPopover({ )} + {tab === 'llm' && ( +
+ + setContextLength(e.target.value)} + /> +
+ )} + Promise; onTestModel: ( name: string, @@ -92,6 +93,11 @@ export default function ModelItem({ const [editAbilities, setEditAbilities] = useState( modelType === 'llm' ? (model as LLMModel).abilities || [] : [], ); + const [editContextLength, setEditContextLength] = useState( + modelType === 'llm' && (model as LLMModel).context_length + ? String((model as LLMModel).context_length) + : '', + ); const [editExtraArgs, setEditExtraArgs] = useState( convertExtraArgsToArray(model.extra_args), ); @@ -106,13 +112,27 @@ export default function ModelItem({ setEditAbilities( modelType === 'llm' ? (model as LLMModel).abilities || [] : [], ); + setEditContextLength( + modelType === 'llm' && (model as LLMModel).context_length + ? String((model as LLMModel).context_length) + : '', + ); setEditExtraArgs(convertExtraArgsToArray(model.extra_args)); onResetTestResult(); } }, [isEditOpen]); const handleSave = async () => { - await onUpdateModel(editName, editAbilities, editExtraArgs); + const parsedContextLength = + modelType === 'llm' && editContextLength.trim() + ? Number(editContextLength.trim()) + : null; + await onUpdateModel( + editName, + editAbilities, + editExtraArgs, + parsedContextLength, + ); }; const handleTest = async () => { @@ -287,6 +307,25 @@ export default function ModelItem({ )} + {modelType === 'llm' && ( +
+ + setEditContextLength(e.target.value)} + /> +
+ )} + Promise; onScanModels: (modelType?: ModelType) => Promise; onAddScannedModels: ( @@ -74,6 +76,7 @@ interface ProviderCardProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onOpenDeleteConfirm: (modelId: string) => void; onCloseDeleteConfirm: () => void; @@ -99,6 +102,7 @@ function maskApiKey(key: string): string { export default function ProviderCard({ provider, isLangBotModels = false, + supportTypes, isExpanded, isLoading, models, @@ -319,6 +323,7 @@ export default function ProviderCard({ addModelMode === 'manual' } initialMode="manual" + supportTypes={supportTypes} trigger={ + ))} + + +
+ + + + + + + + + + + formatNumber(Number(v))} + /> + formatNumber(Number(value))} + /> + + + + + + +
+ + + {/* Per-model breakdown */} +
+

+ {t('monitoring.tokens.byModel')} +

+
+ + + + + + + + + + + + + + {by_model.map((m) => { + const share = + summary.total_tokens > 0 + ? (m.total_tokens / summary.total_tokens) * 100 + : 0; + return ( + + + + + + + + + + ); + })} + +
+ {t('monitoring.tokens.model')} + + {t('monitoring.tokens.calls')} + + {t('monitoring.tokens.inputTokens')} + + {t('monitoring.tokens.outputTokens')} + + {t('monitoring.tokens.totalTokens')} + + {t('monitoring.tokens.avgPerCall')} + + {t('monitoring.tokens.avgLatency')} +
+
+ {m.model_name} +
+
+
+
+
+ {m.calls} + {m.error_calls > 0 && ( + + {' '} + ({m.error_calls}✕) + + )} + + {formatNumber(m.input_tokens)} + + {formatNumber(m.output_tokens)} + + {formatNumber(m.total_tokens)} + + {formatNumber(m.avg_tokens_per_call)} + + {m.avg_duration_ms}ms +
+
+
+ + ); +} diff --git a/web/src/app/home/monitoring/page.tsx b/web/src/app/home/monitoring/page.tsx index 5a75df0a..7dbe2e59 100644 --- a/web/src/app/home/monitoring/page.tsx +++ b/web/src/app/home/monitoring/page.tsx @@ -13,6 +13,7 @@ import { } from 'lucide-react'; import OverviewCards from './components/overview-cards/OverviewCards'; import MonitoringFilters from './components/filters/MonitoringFilters'; +import TokenMonitoring from './components/TokenMonitoring'; import { ExportDropdown } from './components/ExportDropdown'; import { useMonitoringFilters } from './hooks/useMonitoringFilters'; import { useMonitoringData } from './hooks/useMonitoringData'; @@ -319,6 +320,9 @@ function MonitoringPageContent() { {t('monitoring.tabs.modelCalls')} + + {t('monitoring.tabs.tokens')} + {t('monitoring.tabs.feedback')} @@ -668,6 +672,24 @@ function MonitoringPageContent() { + + 0 + ? filterState.selectedBots + : undefined + } + pipelineIds={ + filterState.selectedPipelines.length > 0 + ? filterState.selectedPipelines + : undefined + } + startTime={feedbackTimeRange.startTime} + endTime={feedbackTimeRange.endTime} + refreshKey={feedbackRefreshKey} + /> + +
{loading && ( diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index b9c3a90f..c9a5d01e 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -30,6 +30,8 @@ export interface Requester { spec: { config: IDynamicFormItemSchema[]; provider_category: string; + support_type?: string[]; + alias?: string; }; } @@ -96,6 +98,7 @@ export interface LLMModel { provider_uuid: string; provider?: ModelProvider; abilities?: string[]; + context_length?: number | null; extra_args?: object; } diff --git a/web/src/app/infra/http/BackendClient.ts b/web/src/app/infra/http/BackendClient.ts index 7b65897e..e9cdb51c 100644 --- a/web/src/app/infra/http/BackendClient.ts +++ b/web/src/app/infra/http/BackendClient.ts @@ -1224,6 +1224,68 @@ export class BackendClient extends BaseHttpClient { return this.get(`/api/v1/monitoring/overview?${queryParams.toString()}`); } + public getTokenStatistics(params: { + botId?: string[]; + pipelineId?: string[]; + startTime?: string; + endTime?: string; + bucket?: 'hour' | 'day'; + }): Promise<{ + summary: { + total_calls: number; + success_calls: number; + error_calls: number; + total_input_tokens: number; + total_output_tokens: number; + total_tokens: number; + total_cost: number; + avg_tokens_per_call: number; + avg_duration_ms: number; + avg_tokens_per_second: number; + zero_token_success_calls: number; + }; + by_model: Array<{ + model_name: string; + calls: number; + error_calls: number; + input_tokens: number; + output_tokens: number; + total_tokens: number; + cost: number; + avg_tokens_per_call: number; + avg_duration_ms: number; + }>; + timeseries: Array<{ + bucket: string; + input_tokens: number; + output_tokens: number; + total_tokens: number; + calls: number; + }>; + bucket: string; + }> { + const queryParams = new URLSearchParams(); + if (params.botId) { + params.botId.forEach((id) => queryParams.append('botId', id)); + } + if (params.pipelineId) { + params.pipelineId.forEach((id) => queryParams.append('pipelineId', id)); + } + if (params.startTime) { + queryParams.append('startTime', params.startTime); + } + if (params.endTime) { + queryParams.append('endTime', params.endTime); + } + if (params.bucket) { + queryParams.append('bucket', params.bucket); + } + + return this.get( + `/api/v1/monitoring/token-statistics?${queryParams.toString()}`, + ); + } + // ============ Survey API ============ public getSurveyPending(): Promise<{ survey: { diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index cd507222..ddf1ab43 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -201,6 +201,9 @@ const enUS = { selectModelAbilities: 'Select model abilities', visionAbility: 'Vision Ability', functionCallAbility: 'Function Call', + contextLength: 'Context Window', + contextLengthPlaceholder: 'Unknown', + contextLengthInvalid: 'Context window must be a positive integer', extraParameters: 'Extra Parameters', addParameter: 'Add Parameter', keyName: 'Key Name', @@ -258,6 +261,7 @@ const enUS = { selectProvider: 'Select Provider', requester: 'Provider Type', selectRequester: 'Select Provider Type', + searchProviders: 'Search providers...', langbotModelsDescription: 'Cloud models powered by LangBot Space', credits: 'Credits', loginWithSpace: 'Login with Space', @@ -1201,6 +1205,7 @@ const enUS = { llmCalls: 'LLM Calls', embeddingCalls: 'Embedding Calls', modelCalls: 'Model Calls', + tokens: 'Token Monitoring', feedback: 'User Feedback', sessions: 'Session Analysis', errors: 'Error Logs', @@ -1239,6 +1244,30 @@ const enUS = { avgDuration: 'Avg Duration', calls: 'Calls', }, + tokens: { + totalTokens: 'Total Tokens', + inputTokens: 'Input Tokens', + outputTokens: 'Output Tokens', + avgPerCall: 'Avg / Call', + throughput: 'Throughput', + tokensPerSec: 'tokens/sec', + errorCalls: 'Failed Calls', + acrossCalls: 'across {{count}} calls', + ofTotal: 'of {{count}} total', + usageOverTime: 'Token Usage Over Time', + byModel: 'By Model', + model: 'Model', + calls: 'Calls', + avgLatency: 'Avg Latency', + noData: 'No token usage in the selected time range', + loadError: 'Failed to load token statistics: {{error}}', + zeroTokenWarning: + '{{count}} successful call(s) reported zero token usage. This usually means the upstream provider did not return usage info — check the model provider configuration.', + bucket: { + hour: 'Hourly', + day: 'Daily', + }, + }, embeddingCalls: { title: 'Embedding Calls', model: 'Model', diff --git a/web/src/i18n/locales/es-ES.ts b/web/src/i18n/locales/es-ES.ts index 736168ee..2cb27b4e 100644 --- a/web/src/i18n/locales/es-ES.ts +++ b/web/src/i18n/locales/es-ES.ts @@ -206,6 +206,9 @@ const esES = { selectModelAbilities: 'Seleccionar capacidades del modelo', visionAbility: 'Capacidad de visión', functionCallAbility: 'Llamada a funciones', + contextLength: 'Ventana de contexto', + contextLengthPlaceholder: 'Desconocido', + contextLengthInvalid: 'La ventana de contexto debe ser un entero positivo', extraParameters: 'Parámetros adicionales', addParameter: 'Añadir parámetro', keyName: 'Nombre de la clave', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 5e851a91..1ef18075 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -204,6 +204,10 @@ const jaJP = { selectModelAbilities: 'モデル機能を選択', visionAbility: '視覚機能', functionCallAbility: '関数呼び出し', + contextLength: 'コンテキストウィンドウ', + contextLengthPlaceholder: '不明', + contextLengthInvalid: + 'コンテキストウィンドウは正の整数である必要があります', extraParameters: '追加パラメータ', addParameter: 'パラメータを追加', keyName: 'キー名', diff --git a/web/src/i18n/locales/ru-RU.ts b/web/src/i18n/locales/ru-RU.ts index ab1cb399..624665ec 100644 --- a/web/src/i18n/locales/ru-RU.ts +++ b/web/src/i18n/locales/ru-RU.ts @@ -203,6 +203,10 @@ const ruRU = { selectModelAbilities: 'Выберите возможности модели', visionAbility: 'Распознавание изображений', functionCallAbility: 'Вызов функций', + contextLength: 'Контекстное окно', + contextLengthPlaceholder: 'Неизвестно', + contextLengthInvalid: + 'Контекстное окно должно быть положительным целым числом', extraParameters: 'Дополнительные параметры', addParameter: 'Добавить параметр', keyName: 'Имя ключа', diff --git a/web/src/i18n/locales/th-TH.ts b/web/src/i18n/locales/th-TH.ts index 6a52b226..9d52fca7 100644 --- a/web/src/i18n/locales/th-TH.ts +++ b/web/src/i18n/locales/th-TH.ts @@ -199,6 +199,9 @@ const thTH = { selectModelAbilities: 'เลือกความสามารถของโมเดล', visionAbility: 'ความสามารถด้านภาพ', functionCallAbility: 'การเรียกฟังก์ชัน', + contextLength: 'หน้าต่างบริบท', + contextLengthPlaceholder: 'ไม่ทราบ', + contextLengthInvalid: 'หน้าต่างบริบทต้องเป็นจำนวนเต็มบวก', extraParameters: 'พารามิเตอร์เพิ่มเติม', addParameter: 'เพิ่มพารามิเตอร์', keyName: 'ชื่อคีย์', diff --git a/web/src/i18n/locales/vi-VN.ts b/web/src/i18n/locales/vi-VN.ts index 95bb84ac..c29cc60e 100644 --- a/web/src/i18n/locales/vi-VN.ts +++ b/web/src/i18n/locales/vi-VN.ts @@ -203,6 +203,9 @@ const viVN = { selectModelAbilities: 'Chọn khả năng mô hình', visionAbility: 'Khả năng thị giác', functionCallAbility: 'Gọi hàm', + contextLength: 'Cửa sổ ngữ cảnh', + contextLengthPlaceholder: 'Không rõ', + contextLengthInvalid: 'Cửa sổ ngữ cảnh phải là số nguyên dương', extraParameters: 'Tham số bổ sung', addParameter: 'Thêm tham số', keyName: 'Tên khóa', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index da0aac2e..a7ef0a53 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -193,6 +193,9 @@ const zhHans = { selectModelAbilities: '选择模型能力', visionAbility: '视觉能力', functionCallAbility: '函数调用', + contextLength: '上下文窗口', + contextLengthPlaceholder: '未知', + contextLengthInvalid: '上下文窗口必须是正整数', extraParameters: '额外参数', addParameter: '添加参数', keyName: '键名', @@ -248,6 +251,7 @@ const zhHans = { selectProvider: '选择供应商', requester: '供应商类型', selectRequester: '选择供应商类型', + searchProviders: '搜索供应商...', langbotModelsDescription: 'LangBot Space 提供的云端模型', credits: '积分', loginWithSpace: '通过 Space 登录', @@ -1144,6 +1148,7 @@ const zhHans = { llmCalls: 'LLM调用', embeddingCalls: 'Embedding调用', modelCalls: '模型调用', + tokens: 'Token 监控', feedback: '用户反馈', sessions: '会话分析', errors: '错误日志', @@ -1182,6 +1187,30 @@ const zhHans = { avgDuration: '平均耗时', calls: '调用次数', }, + tokens: { + totalTokens: '总 Token 数', + inputTokens: '输入 Token', + outputTokens: '输出 Token', + avgPerCall: '平均每次调用', + throughput: '吞吐量', + tokensPerSec: 'Token/秒', + errorCalls: '失败调用', + acrossCalls: '共 {{count}} 次调用', + ofTotal: '共 {{count}} 次', + usageOverTime: 'Token 用量趋势', + byModel: '按模型统计', + model: '模型', + calls: '调用次数', + avgLatency: '平均延迟', + noData: '所选时间范围内暂无 Token 用量数据', + loadError: '加载 Token 统计失败:{{error}}', + zeroTokenWarning: + '检测到 {{count}} 次成功调用未上报 Token 用量(记为 0)。这通常表示上游未返回 usage 信息,请检查模型供应商配置。', + bucket: { + hour: '按小时', + day: '按天', + }, + }, embeddingCalls: { title: 'Embedding调用', model: '模型', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 95c76ebd..b7e23500 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -193,6 +193,9 @@ const zhHant = { selectModelAbilities: '選擇模型能力', visionAbility: '視覺能力', functionCallAbility: '函數呼叫', + contextLength: '上下文視窗', + contextLengthPlaceholder: '未知', + contextLengthInvalid: '上下文視窗必須是正整數', extraParameters: '額外參數', addParameter: '新增參數', keyName: '鍵名',