diff --git a/src/langbot/pkg/api/http/controller/groups/provider/providers.py b/src/langbot/pkg/api/http/controller/groups/provider/providers.py index b28bb3e5..d303f178 100644 --- a/src/langbot/pkg/api/http/controller/groups/provider/providers.py +++ b/src/langbot/pkg/api/http/controller/groups/provider/providers.py @@ -43,3 +43,12 @@ class ModelProvidersRouterGroup(group.RouterGroup): return self.success() except ValueError as e: return self.http_status(400, -1, str(e)) + + @self.route('//scan-models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _(provider_uuid: str) -> str: + try: + model_type = quart.request.args.get('type') + result = await self.ap.provider_service.scan_provider_models(provider_uuid, model_type) + return self.success(data=result) + except ValueError as e: + return self.http_status(400, -1, str(e)) diff --git a/src/langbot/pkg/api/http/service/provider.py b/src/langbot/pkg/api/http/service/provider.py index 1abb6e9f..24354731 100644 --- a/src/langbot/pkg/api/http/service/provider.py +++ b/src/langbot/pkg/api/http/service/provider.py @@ -1,6 +1,7 @@ from __future__ import annotations import uuid +import traceback import sqlalchemy @@ -164,3 +165,66 @@ class ModelProviderService: .values(api_keys=[api_key]) ) await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000') + + async def scan_provider_models(self, provider_uuid: str, model_type: str | None = None) -> dict: + provider = await self.get_provider(provider_uuid) + if provider is None: + raise ValueError('provider not found') + + runtime_provider = await self.ap.model_mgr.load_provider(provider) + + try: + scan_result = await runtime_provider.requester.scan_models( + runtime_provider.token_mgr.get_token() if runtime_provider.token_mgr.tokens else None + ) + except NotImplementedError: + raise ValueError('current provider does not support model scanning') + except Exception as exc: + self.ap.logger.warning( + f'Failed to scan models for provider {provider_uuid}: {exc}\n{traceback.format_exc()}' + ) + raise ValueError(str(exc)) from exc + + if isinstance(scan_result, dict): + scanned_models = scan_result.get('models', []) + debug_info = scan_result.get('debug') + else: + scanned_models = scan_result + debug_info = None + + llm_models = await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid) + embedding_models = await self.ap.embedding_models_service.get_embedding_models_by_provider(provider_uuid) + existing_llm_names = {model['name'] for model in llm_models} + existing_embedding_names = {model['name'] for model in embedding_models} + + filtered_models = [] + for model in scanned_models: + scanned_type = model.get('type', 'llm') + if model_type and scanned_type != model_type: + continue + + model_name = model.get('name') or model.get('id') + if not model_name: + continue + + filtered_models.append( + { + 'id': model.get('id', model_name), + 'name': model_name, + 'type': scanned_type, + 'abilities': model.get('abilities', []), + 'display_name': model.get('display_name'), + 'description': model.get('description'), + 'context_length': model.get('context_length'), + 'owned_by': model.get('owned_by'), + 'input_modalities': model.get('input_modalities', []), + 'output_modalities': model.get('output_modalities', []), + 'already_added': ( + model_name in existing_embedding_names + if scanned_type == 'embedding' + else model_name in existing_llm_names + ), + } + ) + + return {'models': filtered_models, 'debug': debug_info} diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index bd70ce61..bce10b6a 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -227,7 +227,8 @@ class ModelManager: raise provider_errors.RequesterNotFoundError(provider_entity.requester) requester_inst = self.requester_dict[provider_entity.requester]( - ap=self.ap, config={'base_url': provider_entity.base_url} + ap=self.ap, + config={'base_url': provider_entity.base_url}, ) await requester_inst.initialize() diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index c281d8ae..301bdfe9 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -303,6 +303,14 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): async def initialize(self): pass + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any] | list[dict[str, typing.Any]]: + """Scan models supported by the provider. + + The default implementation does not support scanning. Requesters that + can enumerate remote models should override this method. + """ + raise NotImplementedError('This provider does not support model scanning') + @abc.abstractmethod async def invoke_llm( self, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py index cbf5543e..24f7a200 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -31,6 +31,192 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), ) + def _mask_api_key(self, api_key: str | None) -> str: + if not api_key: + return '' + if len(api_key) <= 8: + return '****' + return f'{api_key[:4]}...{api_key[-4:]}' + + def _infer_model_type(self, model_id: str) -> str: + normalized_model_id = (model_id or '').lower() + embedding_keywords = ( + 'embedding', + 'embed', + 'bge-', + 'e5-', + 'm3e', + 'gte-', + 'multilingual-e5', + 'text-embedding', + ) + return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' + + def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: + normalized_model_id = (model_id or '').lower() + abilities: set[str] = set() + + def _flatten(value: typing.Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value.lower()] + if isinstance(value, dict): + flattened: list[str] = [] + for nested_value in value.values(): + flattened.extend(_flatten(nested_value)) + return flattened + if isinstance(value, (list, tuple, set)): + flattened: list[str] = [] + for nested_value in value: + flattened.extend(_flatten(nested_value)) + return flattened + return [str(value).lower()] + + capability_tokens = _flatten(item.get('capabilities')) + capability_tokens.extend(_flatten(item.get('modalities'))) + capability_tokens.extend(_flatten(item.get('input_modalities'))) + capability_tokens.extend(_flatten(item.get('output_modalities'))) + capability_tokens.extend(_flatten(item.get('supported_generation_methods'))) + capability_tokens.extend(_flatten(item.get('supported_parameters'))) + capability_tokens.extend(_flatten(item.get('architecture'))) + + combined_tokens = capability_tokens + [normalized_model_id] + + vision_keywords = ( + 'vision', + 'image', + 'file', + 'video', + 'multimodal', + 'vl', + 'ocr', + 'omni', + ) + function_call_keywords = ( + 'function', + 'tool', + 'tools', + 'tool_choice', + 'tool_call', + 'tool-use', + 'tool_use', + ) + + if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens): + abilities.add('vision') + + if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens): + abilities.add('func_call') + + return sorted(abilities) + + def _normalize_modalities(self, value: typing.Any) -> list[str]: + normalized: list[str] = [] + + def _collect(item: typing.Any): + if item is None: + return + if isinstance(item, str): + for part in item.replace('->', ',').replace('+', ',').split(','): + token = part.strip().lower() + if token and token not in normalized: + normalized.append(token) + return + if isinstance(item, dict): + for nested in item.values(): + _collect(nested) + return + if isinstance(item, (list, tuple, set)): + for nested in item: + _collect(nested) + return + + _collect(value) + return normalized + + def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]: + display_name = item.get('name') + if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id: + display_name = '' + + description = item.get('description') + if not isinstance(description, str) or not description.strip(): + description = '' + + context_length = item.get('context_length') + if context_length is None and isinstance(item.get('top_provider'), dict): + context_length = item['top_provider'].get('context_length') + + if not isinstance(context_length, int): + try: + context_length = int(context_length) if context_length is not None else None + except (TypeError, ValueError): + context_length = None + + input_modalities = self._normalize_modalities(item.get('input_modalities')) + output_modalities = self._normalize_modalities(item.get('output_modalities')) + + if isinstance(item.get('architecture'), dict): + if not input_modalities: + input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities')) + if not output_modalities: + output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities')) + + owned_by = item.get('owned_by') + if not isinstance(owned_by, str) or not owned_by.strip(): + owned_by = '' + + return { + 'display_name': display_name or None, + 'description': description or None, + 'context_length': context_length, + 'owned_by': owned_by or None, + 'input_modalities': input_modalities, + 'output_modalities': output_modalities, + } + + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + headers = {} + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + + models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models' + async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: + response = await client.get(models_url, headers=headers) + response.raise_for_status() + payload = response.json() + + models = [] + for item in payload.get('data', []): + model_id = item.get('id') + if not model_id: + continue + models.append( + { + 'id': model_id, + 'name': model_id, + 'type': self._infer_model_type(model_id), + 'abilities': self._infer_model_abilities(item, model_id), + **self._extract_scan_metadata(item, model_id), + } + ) + + models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) + return { + 'models': models, + 'debug': { + 'request': { + 'method': 'GET', + 'url': models_url, + 'headers': { + 'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '', + }, + }, + 'response': payload, + }, + } + async def _req( self, args: dict, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py index f934145e..956b49f6 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +import httpx from . import chatcmpl @@ -20,6 +21,68 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): 'timeout': 120, } + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + models_url = 'https://generativelanguage.googleapis.com/v1beta/models' + params = {'key': api_key} if api_key else {} + + all_models: list[dict[str, typing.Any]] = [] + next_page_token = '' + last_payload: dict[str, typing.Any] = {} + + async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: + while True: + request_params = dict(params) + if next_page_token: + request_params['pageToken'] = next_page_token + + response = await client.get(models_url, params=request_params) + response.raise_for_status() + payload = response.json() + last_payload = payload + + for item in payload.get('models', []): + model_name = item.get('name', '') + model_id = model_name.replace('models/', '', 1) + if not model_id: + continue + + supported_methods = item.get('supportedGenerationMethods', []) or [] + if 'embedContent' in supported_methods and 'generateContent' not in supported_methods: + model_type = 'embedding' + else: + model_type = 'llm' + + all_models.append( + { + 'id': model_id, + 'name': model_id, + 'type': model_type, + 'abilities': self._infer_model_abilities(item, model_id), + 'display_name': item.get('displayName') or None, + 'description': item.get('description') or None, + 'context_length': item.get('inputTokenLimit'), + 'input_modalities': self._normalize_modalities(item.get('inputModalities')), + 'output_modalities': self._normalize_modalities(item.get('outputModalities')), + } + ) + + next_page_token = payload.get('nextPageToken', '') + if not next_page_token: + break + + all_models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) + return { + 'models': all_models, + 'debug': { + 'request': { + 'method': 'GET', + 'url': models_url, + 'query': {'key': self._mask_api_key(api_key)} if api_key else {}, + }, + 'response': last_payload, + }, + } + async def _closure_stream( self, query: pipeline_query.Query, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 2ed4e3b5..ed5d8795 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -31,6 +31,175 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), ) + def _mask_api_key(self, api_key: str | None) -> str: + if not api_key: + return '' + if len(api_key) <= 8: + return '****' + return f'{api_key[:4]}...{api_key[-4:]}' + + def _infer_model_type(self, model_id: str) -> str: + normalized_model_id = (model_id or '').lower() + embedding_keywords = ( + 'embedding', + 'embed', + 'bge-', + 'e5-', + 'm3e', + 'gte-', + 'multilingual-e5', + 'text-embedding', + ) + return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' + + def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: + normalized_model_id = (model_id or '').lower() + abilities: set[str] = set() + + def _flatten(value: typing.Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value.lower()] + if isinstance(value, dict): + flattened: list[str] = [] + for nested_value in value.values(): + flattened.extend(_flatten(nested_value)) + return flattened + if isinstance(value, (list, tuple, set)): + flattened: list[str] = [] + for nested_value in value: + flattened.extend(_flatten(nested_value)) + return flattened + return [str(value).lower()] + + capability_tokens = _flatten(item.get('capabilities')) + capability_tokens.extend(_flatten(item.get('modalities'))) + capability_tokens.extend(_flatten(item.get('input_modalities'))) + capability_tokens.extend(_flatten(item.get('output_modalities'))) + capability_tokens.extend(_flatten(item.get('supported_generation_methods'))) + capability_tokens.extend(_flatten(item.get('supported_parameters'))) + capability_tokens.extend(_flatten(item.get('architecture'))) + + combined_tokens = capability_tokens + [normalized_model_id] + + vision_keywords = ('vision', 'image', 'file', 'video', 'multimodal', 'vl', 'ocr', 'omni') + function_call_keywords = ('function', 'tool', 'tools', 'tool_choice', 'tool_call', 'tool-use', 'tool_use') + + if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens): + abilities.add('vision') + + if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens): + abilities.add('func_call') + + return sorted(abilities) + + def _normalize_modalities(self, value: typing.Any) -> list[str]: + normalized: list[str] = [] + + def _collect(item: typing.Any): + if item is None: + return + if isinstance(item, str): + for part in item.replace('->', ',').replace('+', ',').split(','): + token = part.strip().lower() + if token and token not in normalized: + normalized.append(token) + return + if isinstance(item, dict): + for nested in item.values(): + _collect(nested) + return + if isinstance(item, (list, tuple, set)): + for nested in item: + _collect(nested) + return + + _collect(value) + return normalized + + def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]: + display_name = item.get('name') + if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id: + display_name = '' + + description = item.get('description') + if not isinstance(description, str) or not description.strip(): + description = '' + + context_length = item.get('context_length') + if context_length is None and isinstance(item.get('top_provider'), dict): + context_length = item['top_provider'].get('context_length') + + if not isinstance(context_length, int): + try: + context_length = int(context_length) if context_length is not None else None + except (TypeError, ValueError): + context_length = None + + input_modalities = self._normalize_modalities(item.get('input_modalities')) + output_modalities = self._normalize_modalities(item.get('output_modalities')) + + if isinstance(item.get('architecture'), dict): + if not input_modalities: + input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities')) + if not output_modalities: + output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities')) + + owned_by = item.get('owned_by') + if not isinstance(owned_by, str) or not owned_by.strip(): + owned_by = '' + + return { + 'display_name': display_name or None, + 'description': description or None, + 'context_length': context_length, + 'owned_by': owned_by or None, + 'input_modalities': input_modalities, + 'output_modalities': output_modalities, + } + + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + headers = {} + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + + models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models' + async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: + response = await client.get(models_url, headers=headers) + response.raise_for_status() + payload = response.json() + + models = [] + for item in payload.get('data', []): + model_id = item.get('id') + if not model_id: + continue + models.append( + { + 'id': model_id, + 'name': model_id, + 'type': self._infer_model_type(model_id), + 'abilities': self._infer_model_abilities(item, model_id), + **self._extract_scan_metadata(item, model_id), + } + ) + + models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) + return { + 'models': models, + 'debug': { + 'request': { + 'method': 'GET', + 'url': models_url, + 'headers': { + 'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '', + }, + }, + 'response': payload, + }, + } + async def _req( self, query: pipeline_query.Query, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py index e89a65fa..50f601d7 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py @@ -8,6 +8,7 @@ import uuid import json import ollama +import httpx from .. import errors, requester import langbot_plugin.api.entities.builtin.resource.tool as resource_tool @@ -31,6 +32,60 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url'] self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout']) + def _infer_model_type(self, model_id: str) -> str: + normalized_model_id = (model_id or '').lower() + embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding') + return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm' + + def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]: + normalized_model_id = (model_id or '').lower() + abilities: set[str] = set() + details = item.get('details', {}) or {} + families = details.get('families', []) or [] + tokens = [normalized_model_id, str(details.get('family', '')).lower()] + tokens.extend(str(family).lower() for family in families) + + if any(keyword in token for token in tokens for keyword in ('vision', 'vl', 'omni', 'llava', 'ocr')): + abilities.add('vision') + if any(keyword in token for token in tokens for keyword in ('tool', 'function')): + abilities.add('func_call') + return sorted(abilities) + + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + del api_key + models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/api/tags' + + async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client: + response = await client.get(models_url) + response.raise_for_status() + payload = response.json() + + models: list[dict[str, typing.Any]] = [] + for item in payload.get('models', []): + model_id = item.get('model') or item.get('name') + if not model_id: + continue + models.append( + { + 'id': model_id, + 'name': item.get('name', model_id), + 'type': self._infer_model_type(model_id), + 'abilities': self._infer_model_abilities(item, model_id), + } + ) + + models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower())) + return { + 'models': models, + 'debug': { + 'request': { + 'method': 'GET', + 'url': models_url, + }, + 'response': payload, + }, + } + async def _req( self, args: dict, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py index a73437d3..17b88431 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.py @@ -15,3 +15,11 @@ class OpenRouterChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions): 'base_url': 'https://openrouter.ai/api/v1', 'timeout': 120, } + + async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]: + original_base_url = self.requester_cfg.get('base_url', '') + self.requester_cfg['base_url'] = 'https://openrouter.ai/api/v1' + try: + return await super().scan_models(api_key) + finally: + self.requester_cfg['base_url'] = original_base_url diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index e2268400..72694ad3 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -16,6 +16,8 @@ import { ProviderCard } from './components'; import { ExtraArg, ModelType, + ScanModelsResult, + SelectedScannedModel, TestResult, ProviderModels, LANGBOT_MODELS_PROVIDER_REQUESTER, @@ -262,6 +264,60 @@ export default function ModelsDialog({ } } + async function handleScanModels( + providerUuid: string, + modelType: ModelType, + ): Promise { + try { + const resp = await httpClient.scanProviderModels(providerUuid, modelType); + return { + models: resp.models, + debug: resp.debug, + }; + } catch (err) { + toast.error(t('models.getModelListError') + (err as CustomApiError).msg); + return { models: [] }; + } + } + + async function handleAddScannedModels( + providerUuid: string, + modelType: ModelType, + models: SelectedScannedModel[], + ) { + if (models.length === 0) return; + + setIsSubmitting(true); + try { + for (const item of models) { + if (modelType === 'llm') { + await httpClient.createProviderLLMModel({ + name: item.model.name, + provider_uuid: providerUuid, + abilities: item.abilities, + extra_args: {}, + } as never); + } else { + await httpClient.createProviderEmbeddingModel({ + name: item.model.name, + provider_uuid: providerUuid, + extra_args: {}, + } as never); + } + } + setAddModelPopoverOpen(null); + loadProviderModels(providerUuid, true); + loadProviders(); + toast.success( + t('models.addSelectedModelsSuccess', { count: models.length }), + ); + } catch (err) { + toast.error(t('models.createError') + (err as CustomApiError).msg); + } finally { + setIsSubmitting(false); + } + } + async function handleUpdateModel( providerUuid: string, modelId: string, @@ -404,6 +460,10 @@ export default function ModelsDialog({ onAddModel={(modelType, name, abilities, extraArgs) => handleAddModel(provider.uuid, modelType, name, abilities, extraArgs) } + onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)} + onAddScannedModels={(modelType, models) => + handleAddScannedModels(provider.uuid, modelType, models) + } onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)} onCloseEditModel={() => setEditModelPopoverOpen(null)} onUpdateModel={(modelId, modelType, name, abilities, extraArgs) => diff --git a/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx index 94e4b0f9..df3aea26 100644 --- a/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx +++ b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx @@ -169,8 +169,6 @@ export default function ProviderForm({ onValueChange={(v) => { field.onChange(v); const req = requesterList.find((r) => r.value === v); - // Auto-fill default URL when creating new provider - // or when base_url is empty in edit mode if (req && (!providerId || !form.getValues('base_url'))) { form.setValue('base_url', req.defaultUrl); } diff --git a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx index 91151b11..bdbf90a8 100644 --- a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx +++ b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx @@ -1,5 +1,14 @@ -import { useState, useEffect } from 'react'; -import { Plus, MessageSquareText, Cpu, Eye, Wrench, Check } from 'lucide-react'; +import { useState, useEffect, useRef } from 'react'; +import { + Plus, + MessageSquareText, + Cpu, + Eye, + Wrench, + Check, + RefreshCw, + Search, +} from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; import { Label } from '@/components/ui/label'; @@ -11,7 +20,14 @@ import { } from '@/components/ui/popover'; import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { useTranslation } from 'react-i18next'; -import { ExtraArg, ModelType, TestResult } from '../types'; +import { ScannedProviderModel } from '@/app/infra/entities/api'; +import { + ExtraArg, + ModelType, + ScanModelsResult, + SelectedScannedModel, + TestResult, +} from '../types'; import ExtraArgsEditor from './ExtraArgsEditor'; interface AddModelPopoverProps { @@ -24,6 +40,11 @@ interface AddModelPopoverProps { abilities: string[], extraArgs: ExtraArg[], ) => Promise; + onScanModels: (modelType: ModelType) => Promise; + onAddScannedModels: ( + modelType: ModelType, + models: SelectedScannedModel[], + ) => Promise; onTestModel: ( name: string, modelType: ModelType, @@ -41,6 +62,8 @@ export default function AddModelPopover({ onOpen, onClose, onAddModel, + onScanModels, + onAddScannedModels, onTestModel, isSubmitting, isTesting, @@ -48,22 +71,44 @@ export default function AddModelPopover({ onResetTestResult, }: AddModelPopoverProps) { const { t } = useTranslation(); + const prevIsOpenRef = useRef(false); const [tab, setTab] = useState('llm'); + const [mode, setMode] = useState<'manual' | 'scan'>('manual'); const [name, setName] = useState(''); const [abilities, setAbilities] = useState([]); const [extraArgs, setExtraArgs] = useState([]); + const [scanLoading, setScanLoading] = useState(false); + const [scannedModels, setScannedModels] = useState( + [], + ); + const [selectedScannedModels, setSelectedScannedModels] = useState< + Record + >({}); + const [scanQuery, setScanQuery] = useState(''); - // Reset form when popover opens useEffect(() => { - if (isOpen) { + const wasOpen = prevIsOpenRef.current; + if (isOpen && !wasOpen) { setTab('llm'); + setMode('manual'); setName(''); setAbilities([]); setExtraArgs([]); + setScanLoading(false); + setScannedModels([]); + setSelectedScannedModels({}); + setScanQuery(''); onResetTestResult(); } - }, [isOpen]); + prevIsOpenRef.current = isOpen; + }, [isOpen, onResetTestResult]); + + useEffect(() => { + setScannedModels([]); + setSelectedScannedModels({}); + setScanQuery(''); + }, [tab, mode]); const handleAdd = async () => { await onAddModel(tab, name, abilities, extraArgs); @@ -73,6 +118,50 @@ export default function AddModelPopover({ await onTestModel(name, tab, tab === 'llm' ? abilities : [], extraArgs); }; + const handleScan = async () => { + setScanLoading(true); + try { + const result = await onScanModels(tab); + + // Enrich abilities from debug.response.data (e.g. features.tools.function_calling) + const debugData = ( + result.debug?.response as { data?: Record[] } + )?.data; + if (Array.isArray(debugData)) { + const debugMap = new Map>(); + for (const item of debugData) { + if (typeof item?.id === 'string') { + debugMap.set(item.id, item); + } + } + for (const model of result.models) { + const debugItem = debugMap.get(model.id); + if (!debugItem) continue; + const features = debugItem.features as + | Record + | undefined; + const tools = features?.tools as Record | undefined; + if (tools?.function_calling === true) { + const abilities = new Set(model.abilities || []); + abilities.add('func_call'); + model.abilities = [...abilities]; + } + } + } + + setScannedModels(result.models); + setSelectedScannedModels({}); + } finally { + setScanLoading(false); + } + }; + + const handleAddScanned = async () => { + const selectedModels = Object.values(selectedScannedModels); + if (selectedModels.length === 0) return; + await onAddScannedModels(tab, selectedModels); + }; + const toggleAbility = (ability: string, checked: boolean) => { if (checked) { setAbilities([...abilities, ability]); @@ -81,6 +170,76 @@ export default function AddModelPopover({ } }; + const toggleScannedModel = ( + model: ScannedProviderModel, + checked: boolean, + ) => { + setSelectedScannedModels((prev) => { + const next = { ...prev }; + if (checked) { + next[model.id] = { + model, + abilities: + model.type === 'llm' + ? prev[model.id]?.abilities || model.abilities || [] + : [], + }; + } else { + delete next[model.id]; + } + return next; + }); + }; + + const toggleScannedModelAbility = ( + modelId: string, + ability: string, + checked: boolean, + ) => { + setSelectedScannedModels((prev) => { + const current = prev[modelId]; + if (!current) return prev; + + const nextAbilities = checked + ? [...current.abilities, ability] + : current.abilities.filter((item) => item !== ability); + + return { + ...prev, + [modelId]: { + ...current, + abilities: nextAbilities, + }, + }; + }); + }; + + const filteredScannedModels = scannedModels.filter((model) => + model.name.toLowerCase().includes(scanQuery.trim().toLowerCase()), + ); + + const selectableModels = filteredScannedModels.filter( + (m) => !m.already_added, + ); + const allSelected = + selectableModels.length > 0 && + selectableModels.every((m) => Boolean(selectedScannedModels[m.id])); + + const toggleSelectAll = () => { + if (allSelected) { + setSelectedScannedModels({}); + } else { + const next: Record = {}; + for (const model of selectableModels) { + next[model.id] = { + model, + abilities: model.type === 'llm' ? model.abilities || [] : [], + }; + } + setSelectedScannedModels(next); + } + }; + return ( e.stopPropagation()} > setTab(v as ModelType)}> @@ -114,116 +276,260 @@ export default function AddModelPopover({ - -
- - setName(e.target.value)} - /> -
-
- -
-
- - toggleAbility('vision', checked as boolean) - } + setMode(v as 'manual' | 'scan')} + > + + {t('models.manualAdd')} + {t('models.scanAdd')} + + + +
+
+ + setName(e.target.value)} /> -
-
- - toggleAbility('func_call', checked as boolean) - } - /> - + + {tab === 'llm' && ( +
+ +
+
+ + toggleAbility('vision', checked as boolean) + } + /> + +
+
+ + toggleAbility('func_call', checked as boolean) + } + /> + +
+
+
+ )} + + +
+ +
-
- -
- - -
-
+ - -
- - setName(e.target.value)} - /> -
- -
- - + +
+ +
+ + setScanQuery(e.target.value)} + disabled={scannedModels.length === 0} + /> + {selectableModels.length > 0 && ( +
+ + +
)} - -
-
+
+ +
e.stopPropagation()} + > +
+ {filteredScannedModels.length === 0 ? ( +

+ {scannedModels.length === 0 + ? t('models.noScannedModels') + : t('models.noScannedModelsMatch')} +

+ ) : ( + filteredScannedModels.map((model) => { + const isSelected = Boolean( + selectedScannedModels[model.id], + ); + const selectedAbilities = + selectedScannedModels[model.id]?.abilities || []; + return ( +
+
+ + toggleScannedModel(model, checked as boolean) + } + /> +
+
+ {model.name} +
+
+ {model.already_added + ? t('models.alreadyAdded') + : model.type === 'llm' + ? t('models.chat') + : t('models.embedding')} +
+
+
+ + {tab === 'llm' && + isSelected && + !model.already_added && ( +
+
+ + toggleScannedModelAbility( + model.id, + 'vision', + checked as boolean, + ) + } + /> + +
+
+ + toggleScannedModelAbility( + model.id, + 'func_call', + checked as boolean, + ) + } + /> + +
+
+ )} +
+ ); + }) + )} +
+
+ + diff --git a/web/src/app/home/components/models-dialog/components/ProviderCard.tsx b/web/src/app/home/components/models-dialog/components/ProviderCard.tsx index 70adebe6..3c1de7a9 100644 --- a/web/src/app/home/components/models-dialog/components/ProviderCard.tsx +++ b/web/src/app/home/components/models-dialog/components/ProviderCard.tsx @@ -24,7 +24,14 @@ import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Badge } from '@/components/ui/badge'; import { useTranslation } from 'react-i18next'; import langbotIcon from '@/app/assets/langbot-logo.webp'; -import { ExtraArg, ModelType, TestResult, ProviderModels } from '../types'; +import { + ExtraArg, + ModelType, + ScanModelsResult, + SelectedScannedModel, + TestResult, + ProviderModels, +} from '../types'; import ModelItem from './ModelItem'; import AddModelPopover from './AddModelPopover'; @@ -53,6 +60,11 @@ interface ProviderCardProps { abilities: string[], extraArgs: ExtraArg[], ) => Promise; + onScanModels: (modelType: ModelType) => Promise; + onAddScannedModels: ( + modelType: ModelType, + models: SelectedScannedModel[], + ) => Promise; onOpenEditModel: (modelId: string) => void; onCloseEditModel: () => void; onUpdateModel: ( @@ -101,6 +113,8 @@ export default function ProviderCard({ onOpenAddModel, onCloseAddModel, onAddModel, + onScanModels, + onAddScannedModels, onOpenEditModel, onCloseEditModel, onUpdateModel, @@ -298,6 +312,8 @@ export default function ProviderCard({ onOpen={onOpenAddModel} onClose={onCloseAddModel} onAddModel={onAddModel} + onScanModels={onScanModels} + onAddScannedModels={onAddScannedModels} onTestModel={onTestModel} isSubmitting={isSubmitting} isTesting={isTesting} diff --git a/web/src/app/home/components/models-dialog/types.ts b/web/src/app/home/components/models-dialog/types.ts index 15217269..ea52f687 100644 --- a/web/src/app/home/components/models-dialog/types.ts +++ b/web/src/app/home/components/models-dialog/types.ts @@ -2,6 +2,8 @@ import { LLMModel, EmbeddingModel, ModelProvider, + ProviderScanDebugInfo, + ScannedProviderModel, } from '@/app/infra/entities/api'; export type ExtraArg = { @@ -22,6 +24,16 @@ export interface TestResult { duration: number; } +export type SelectedScannedModel = { + model: ScannedProviderModel; + abilities: string[]; +}; + +export type ScanModelsResult = { + models: ScannedProviderModel[]; + debug?: ProviderScanDebugInfo; +}; + export interface ModelItemProps { model: LLMModel | EmbeddingModel; modelType: ModelType; @@ -75,6 +87,11 @@ export interface ProviderCardProps { abilities: string[], extraArgs: ExtraArg[], ) => Promise; + onScanModels: (modelType: ModelType) => Promise; + onAddScannedModels: ( + modelType: ModelType, + models: SelectedScannedModel[], + ) => Promise; onOpenEditModel: (modelId: string) => void; onCloseEditModel: () => void; onUpdateModel: ( diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index a801a9b7..626f6865 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -61,6 +61,34 @@ export interface ApiRespModelProvider { provider: ModelProvider; } +export interface ScannedProviderModel { + id: string; + name: string; + type: 'llm' | 'embedding'; + abilities?: string[]; + display_name?: string; + description?: string; + context_length?: number | null; + owned_by?: string; + input_modalities?: string[]; + output_modalities?: string[]; + already_added: boolean; +} + +export interface ProviderScanDebugInfo { + request?: { + method?: string; + url?: string; + headers?: Record; + }; + response?: unknown; +} + +export interface ApiRespScannedProviderModels { + models: ScannedProviderModel[]; + debug?: ProviderScanDebugInfo; +} + export interface LLMModel { uuid: string; name: string; diff --git a/web/src/app/infra/http/BackendClient.ts b/web/src/app/infra/http/BackendClient.ts index ecc0cce3..97833a41 100644 --- a/web/src/app/infra/http/BackendClient.ts +++ b/web/src/app/infra/http/BackendClient.ts @@ -37,6 +37,7 @@ import { MCPServer, ApiRespModelProviders, ApiRespModelProvider, + ApiRespScannedProviderModels, ModelProvider, ApiRespKnowledgeEngines, ApiRespParsers, @@ -106,6 +107,14 @@ export class BackendClient extends BaseHttpClient { return this.delete(`/api/v1/provider/providers/${uuid}`); } + public scanProviderModels( + uuid: string, + modelType?: 'llm' | 'embedding', + ): Promise { + const params = modelType ? { type: modelType } : {}; + return this.get(`/api/v1/provider/providers/${uuid}/scan-models`, params); + } + // ============ Provider Model LLM ============ public getProviderLLMModels( providerUuid?: string, diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index b888f398..c2644983 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -181,6 +181,10 @@ const enUS = { mustBeValidNumber: 'Must be a valid number', mustBeTrueOrFalse: 'Must be true or false', requestURL: 'Request URL', + scanURL: 'Scan Models URL', + scanURLPlaceholder: 'Leave empty to use Request URL + /models', + scanURLDescription: + 'Fill in the actual model-list endpoint when model scanning does not use the same address as model invocation.', apiKey: 'API Key', abilities: 'Abilities', selectModelAbilities: 'Select model abilities', @@ -218,6 +222,20 @@ const enUS = { providerCount: '{{count}} providers', // New keys for provider-based structure addModel: 'Add Model', + manualAdd: 'Manual', + scanAdd: 'Scan', + scanModels: 'Scan Models', + scanModelsHint: + 'Read available models from the current provider, then select which ones to add.', + scannedModels: 'Scanned Models', + scanDebug: 'Debug Info', + searchScannedModels: 'Search scanned models', + noScannedModels: 'No scan results yet. Click the button above to scan.', + noScannedModelsMatch: 'No matching models', + addSelectedModels: 'Add Selected', + addSelectedModelsSuccess: '{{count}} model(s) added', + selectAll: 'Select All', + alreadyAdded: 'Already added', addLLMModel: 'Add LLM Model', addEmbeddingModel: 'Add Embedding Model', provider: 'Provider', diff --git a/web/src/i18n/locales/es-ES.ts b/web/src/i18n/locales/es-ES.ts index 37bc2f73..e3b02bc2 100644 --- a/web/src/i18n/locales/es-ES.ts +++ b/web/src/i18n/locales/es-ES.ts @@ -227,6 +227,20 @@ const esES = { providerCount: '{{count}} proveedores', // New keys for provider-based structure addModel: 'Añadir modelo', + manualAdd: 'Manual', + scanAdd: 'Escanear', + scanModels: 'Escanear modelos', + scanModelsHint: + 'Lee los modelos disponibles del proveedor actual y luego elige cuáles agregar.', + scannedModels: 'Modelos detectados', + searchScannedModels: 'Buscar modelos detectados', + noScannedModels: + 'Todavía no hay resultados. Pulsa el botón superior para escanear.', + noScannedModelsMatch: 'No hay modelos coincidentes', + addSelectedModels: 'Agregar seleccionados', + addSelectedModelsSuccess: 'Se agregaron {{count}} modelo(s)', + selectAll: 'Seleccionar todo', + alreadyAdded: 'Ya agregado', addLLMModel: 'Añadir modelo LLM', addEmbeddingModel: 'Añadir modelo Embedding', provider: 'Proveedor', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index dba99564..2d1de5f3 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -221,6 +221,20 @@ 'ローカルモデルがありません。作成ボタンをクリックしてモデルを追加してください。', providerCount: '{{count}} 件のプロバイダー', addModel: 'モデルを追加', + manualAdd: '手動追加', + scanAdd: 'スキャン追加', + scanModels: 'モデルをスキャン', + scanModelsHint: + '現在のプロバイダーから利用可能なモデルを取得し、追加するモデルを選択します。', + scannedModels: 'スキャン結果', + searchScannedModels: 'スキャン結果を検索', + noScannedModels: + 'まだスキャン結果がありません。上のボタンからスキャンしてください。', + noScannedModelsMatch: '一致するモデルがありません', + addSelectedModels: '選択したモデルを追加', + addSelectedModelsSuccess: '{{count}} 件のモデルを追加しました', + selectAll: 'すべて選択', + alreadyAdded: '追加済み', addLLMModel: 'LLMモデルを追加', addEmbeddingModel: '埋め込みモデルを追加', provider: 'プロバイダー', diff --git a/web/src/i18n/locales/th-TH.ts b/web/src/i18n/locales/th-TH.ts index 9361bfdb..51e7698e 100644 --- a/web/src/i18n/locales/th-TH.ts +++ b/web/src/i18n/locales/th-TH.ts @@ -215,6 +215,19 @@ const thTH = { noLocalModels: 'ไม่มีโมเดลท้องถิ่น คลิกสร้างเพื่อเพิ่มโมเดล', providerCount: '{{count}} ผู้ให้บริการ', addModel: 'เพิ่มโมเดล', + manualAdd: 'เพิ่มเอง', + scanAdd: 'สแกน', + scanModels: 'สแกนโมเดล', + scanModelsHint: + 'ดึงรายการโมเดลที่ใช้ได้จากผู้ให้บริการปัจจุบัน แล้วเลือกโมเดลที่ต้องการเพิ่ม', + scannedModels: 'ผลการสแกน', + searchScannedModels: 'ค้นหาผลการสแกน', + noScannedModels: 'ยังไม่มีผลการสแกน กดปุ่มด้านบนเพื่อเริ่มสแกน', + noScannedModelsMatch: 'ไม่พบโมเดลที่ตรงกัน', + addSelectedModels: 'เพิ่มที่เลือก', + addSelectedModelsSuccess: 'เพิ่มแล้ว {{count}} โมเดล', + selectAll: 'เลือกทั้งหมด', + alreadyAdded: 'เพิ่มแล้ว', addLLMModel: 'เพิ่มโมเดล LLM', addEmbeddingModel: 'เพิ่มโมเดล Embedding', provider: 'ผู้ให้บริการ', diff --git a/web/src/i18n/locales/vi-VN.ts b/web/src/i18n/locales/vi-VN.ts index c86f053b..5eee1a87 100644 --- a/web/src/i18n/locales/vi-VN.ts +++ b/web/src/i18n/locales/vi-VN.ts @@ -222,6 +222,19 @@ const viVN = { noLocalModels: 'Không có mô hình cục bộ. Nhấn Tạo để thêm mô hình.', providerCount: '{{count}} nhà cung cấp', addModel: 'Thêm mô hình', + manualAdd: 'Thủ công', + scanAdd: 'Quét', + scanModels: 'Quét mô hình', + scanModelsHint: + 'Đọc danh sách mô hình khả dụng từ nhà cung cấp hiện tại rồi chọn mô hình cần thêm.', + scannedModels: 'Kết quả quét', + searchScannedModels: 'Tìm trong kết quả quét', + noScannedModels: 'Chưa có kết quả quét. Nhấn nút phía trên để bắt đầu.', + noScannedModelsMatch: 'Không có mô hình phù hợp', + addSelectedModels: 'Thêm mục đã chọn', + addSelectedModelsSuccess: 'Đã thêm {{count}} mô hình', + selectAll: 'Chọn tất cả', + alreadyAdded: 'Đã thêm', addLLMModel: 'Thêm mô hình LLM', addEmbeddingModel: 'Thêm mô hình Embedding', provider: 'Nhà cung cấp', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 26820d8e..bcc5faa1 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -173,6 +173,10 @@ const zhHans = { mustBeValidNumber: '必须是有效的数字', mustBeTrueOrFalse: '必须是 true 或 false', requestURL: '请求URL', + scanURL: '扫描模型地址', + scanURLPlaceholder: '留空则默认使用请求URL + /models', + scanURLDescription: + '当模型扫描接口与模型调用接口不是同一个地址时,在这里填写实际的模型列表接口。', apiKey: 'API Key', abilities: '能力', selectModelAbilities: '选择模型能力', @@ -209,6 +213,19 @@ const zhHans = { providerCount: '共 {{count}} 个自定义供应商', // 供应商结构新增键 addModel: '添加模型', + manualAdd: '手动添加', + scanAdd: '扫描添加', + scanModels: '扫描模型', + scanModelsHint: '从当前供应商接口读取可用模型,然后勾选要添加的模型。', + scannedModels: '扫描结果', + scanDebug: '调试信息', + searchScannedModels: '搜索扫描结果', + noScannedModels: '还没有扫描结果,点击上方按钮开始扫描。', + noScannedModelsMatch: '没有匹配的模型', + addSelectedModels: '添加所选模型', + addSelectedModelsSuccess: '已添加 {{count}} 个模型', + selectAll: '全选模型', + alreadyAdded: '已添加', addLLMModel: '添加对话模型', addEmbeddingModel: '添加嵌入模型', provider: '供应商', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index c14d9c43..c8c4a906 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -208,6 +208,18 @@ const zhHant = { noLocalModels: '暫無本地模型。點擊建立按鈕新增模型。', providerCount: '共 {{count}} 個供應商', addModel: '新增模型', + manualAdd: '手動添加', + scanAdd: '掃描添加', + scanModels: '掃描模型', + scanModelsHint: '從目前供應商介面讀取可用模型,然後勾選要添加的模型。', + scannedModels: '掃描結果', + searchScannedModels: '搜尋掃描結果', + noScannedModels: '尚無掃描結果,點擊上方按鈕開始掃描。', + noScannedModelsMatch: '沒有符合的模型', + addSelectedModels: '添加所選模型', + addSelectedModelsSuccess: '已添加 {{count}} 個模型', + selectAll: '全選模型', + alreadyAdded: '已添加', addLLMModel: '新增對話模型', addEmbeddingModel: '新增嵌入模型', provider: '供應商',