diff --git a/src/langbot/pkg/api/http/controller/groups/provider/models.py b/src/langbot/pkg/api/http/controller/groups/provider/models.py index 25f16995..cec582ee 100644 --- a/src/langbot/pkg/api/http/controller/groups/provider/models.py +++ b/src/langbot/pkg/api/http/controller/groups/provider/models.py @@ -9,12 +9,15 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) async def _() -> str: if quart.request.method == 'GET': + provider_uuid = quart.request.args.get('provider_uuid') + if provider_uuid: + return self.success( + data={'models': await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)} + ) return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) - return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) @@ -52,12 +55,19 @@ class EmbeddingModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) async def _() -> str: if quart.request.method == 'GET': + provider_uuid = quart.request.args.get('provider_uuid') + if provider_uuid: + return self.success( + data={ + 'models': await self.ap.embedding_models_service.get_embedding_models_by_provider( + provider_uuid + ) + } + ) return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) - return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) diff --git a/src/langbot/pkg/api/http/controller/groups/provider/providers.py b/src/langbot/pkg/api/http/controller/groups/provider/providers.py new file mode 100644 index 00000000..b28bb3e5 --- /dev/null +++ b/src/langbot/pkg/api/http/controller/groups/provider/providers.py @@ -0,0 +1,45 @@ +import quart + +from ... import group + + +@group.group_class('models/providers', '/api/v1/provider/providers') +class ModelProvidersRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _() -> str: + if quart.request.method == 'GET': + providers = await self.ap.provider_service.get_providers() + # Add model counts + for provider in providers: + counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid']) + provider['llm_count'] = counts['llm_count'] + provider['embedding_count'] = counts['embedding_count'] + return self.success(data={'providers': providers}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + provider_uuid = await self.ap.provider_service.create_provider(json_data) + return self.success(data={'uuid': provider_uuid}) + + @self.route( + '/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY + ) + async def _(provider_uuid: str) -> str: + if quart.request.method == 'GET': + provider = await self.ap.provider_service.get_provider(provider_uuid) + if provider is None: + return self.http_status(404, -1, 'provider not found') + counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid) + provider['llm_count'] = counts['llm_count'] + provider['embedding_count'] = counts['embedding_count'] + return self.success(data={'provider': provider}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + await self.ap.provider_service.update_provider(provider_uuid, json_data) + return self.success() + elif quart.request.method == 'DELETE': + try: + await self.ap.provider_service.delete_provider(provider_uuid) + return self.success() + except ValueError as e: + return self.http_status(400, -1, str(e)) diff --git a/src/langbot/pkg/api/http/controller/groups/space.py b/src/langbot/pkg/api/http/controller/groups/space.py deleted file mode 100644 index cefdce19..00000000 --- a/src/langbot/pkg/api/http/controller/groups/space.py +++ /dev/null @@ -1,52 +0,0 @@ -import quart - -from .. import group - - -DEFAULT_SPACE_URL = 'https://space.langbot.app' - - -@group.group_class('space', '/api/v1/space') -class SpaceRouterGroup(group.RouterGroup): - async def initialize(self) -> None: - @self.route('/models/sync', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) - async def _(user_email: str) -> str: - """Sync models from Space MaaS to local database""" - json_data = await quart.request.json or {} - space_url = json_data.get('space_url', DEFAULT_SPACE_URL) - - try: - stats = await self.ap.space_models_service.sync_models_from_space(user_email, space_url) - return self.success(data=stats) - except ValueError as e: - return self.fail(1, str(e)) - except Exception as e: - return self.fail(2, f'Failed to sync models: {str(e)}') - - @self.route('/models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) - async def _(user_email: str) -> str: - """Get all synced Space models""" - if quart.request.method == 'GET': - try: - models = await self.ap.space_models_service.get_space_models() - return self.success(data=models) - except Exception as e: - return self.fail(1, f'Failed to get Space models: {str(e)}') - elif quart.request.method == 'DELETE': - try: - stats = await self.ap.space_models_service.delete_space_models() - return self.success(data=stats) - except Exception as e: - return self.fail(1, f'Failed to delete Space models: {str(e)}') - - @self.route('/models/available', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) - async def _(user_email: str) -> str: - """Get available models from Space (preview before sync)""" - try: - space_url = quart.request.args.get('space_url', DEFAULT_SPACE_URL) - models_data = await self.ap.space_models_service.fetch_space_models(space_url) - return self.success(data=models_data) - except ValueError as e: - return self.fail(1, str(e)) - except Exception as e: - return self.fail(2, f'Failed to fetch available models: {str(e)}') diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index 288f07ad..03f42e3d 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -11,6 +11,18 @@ from ....entity.persistence import pipeline as persistence_pipeline from ....provider.modelmgr import requester as model_requester +def _parse_provider_api_keys(provider_dict: dict) -> dict: + """Parse api_keys if it's a JSON string""" + if isinstance(provider_dict.get('api_keys'), str): + import json + + try: + provider_dict['api_keys'] = json.loads(provider_dict['api_keys']) + except Exception: + provider_dict['api_keys'] = [] + return provider_dict + + class LLMModelsService: ap: app.Application @@ -18,29 +30,64 @@ class LLMModelsService: self.ap = ap async def get_llm_models(self, include_secret: bool = True) -> list[dict]: + """Get all LLM models with provider info""" result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - models = result.all() - masked_columns = [] - if not include_secret: - masked_columns = ['api_keys'] + # Get all providers for lookup + providers_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider) + ) + providers = {p.uuid: p for p in providers_result.all()} - return [ - self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns) - for model in models - ] + models_list = [] + for model in models: + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + provider = providers.get(model.provider_uuid) + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + provider_dict = _parse_provider_api_keys(provider_dict) + if not include_secret: + provider_dict['api_keys'] = ['***'] * len(provider_dict.get('api_keys', [])) + model_dict['provider'] = provider_dict + models_list.append(model_dict) + + return models_list + + async def get_llm_models_by_provider(self, provider_uuid: str) -> list[dict]: + """Get LLM models by provider UUID""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.LLMModel).where( + persistence_model.LLMModel.provider_uuid == provider_uuid + ) + ) + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models] async def create_llm_model(self, model_data: dict) -> str: + """Create a new LLM model""" model_data['uuid'] = str(uuid.uuid4()) + # Handle provider creation if needed + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + # Create new provider + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)) llm_model = await self.get_llm_model(model_data['uuid']) - await self.ap.model_mgr.load_llm_model(llm_model) - # check if default pipeline has no model bound + # Check if default pipeline has no model bound result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( persistence_pipeline.LegacyPipeline.is_default == True @@ -56,21 +103,47 @@ class LLMModelsService: return model_data['uuid'] async def get_llm_model(self, model_uuid: str) -> dict | None: + """Get a single LLM model with provider info""" result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) ) - model = result.first() - if model is None: return None - return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + + # Get provider + provider_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == model.provider_uuid + ) + ) + provider = provider_result.first() + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + model_dict['provider'] = _parse_provider_api_keys(provider_dict) + + return model_dict async def update_llm_model(self, model_uuid: str, model_data: dict) -> None: + """Update an existing LLM model""" if 'uuid' in model_data: del model_data['uuid'] + # Handle provider update if needed + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_model.LLMModel) .where(persistence_model.LLMModel.uuid == model_uuid) @@ -78,19 +151,18 @@ class LLMModelsService: ) await self.ap.model_mgr.remove_llm_model(model_uuid) - llm_model = await self.get_llm_model(model_uuid) - await self.ap.model_mgr.load_llm_model(llm_model) async def delete_llm_model(self, model_uuid: str) -> None: + """Delete an LLM model""" await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) ) - await self.ap.model_mgr.remove_llm_model(model_uuid) async def test_llm_model(self, model_uuid: str, model_data: dict) -> None: + """Test an LLM model""" runtime_llm_model: model_requester.RuntimeLLMModel | None = None if model_uuid != '_': @@ -98,18 +170,11 @@ class LLMModelsService: if model.model_entity.uuid == model_uuid: runtime_llm_model = model break - if runtime_llm_model is None: raise Exception('model not found') - else: runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data) - # Mon Nov 10 2025: Commented for some providers may not support thinking parameter - # # 有些模型厂商默认开启了思考功能,测试容易延迟 - # extra_args = model_data.get('extra_args', {}) - # if not extra_args or 'thinking' not in extra_args: - # extra_args['thinking'] = {'type': 'disabled'} extra_args = model_data.get('extra_args', {}) await runtime_llm_model.requester.invoke_llm( query=None, @@ -127,42 +192,103 @@ class EmbeddingModelsService: self.ap = ap async def get_embedding_models(self) -> list[dict]: + """Get all embedding models with provider info""" result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) - models = result.all() - return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + providers_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider) + ) + providers = {p.uuid: p for p in providers_result.all()} + + models_list = [] + for model in models: + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + provider = providers.get(model.provider_uuid) + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + model_dict['provider'] = _parse_provider_api_keys(provider_dict) + models_list.append(model_dict) + + return models_list + + async def get_embedding_models_by_provider(self, provider_uuid: str) -> list[dict]: + """Get embedding models by provider UUID""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.provider_uuid == provider_uuid + ) + ) + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models] async def create_embedding_model(self, model_data: dict) -> str: + """Create a new embedding model""" model_data['uuid'] = str(uuid.uuid4()) + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) ) embedding_model = await self.get_embedding_model(model_data['uuid']) - await self.ap.model_mgr.load_embedding_model(embedding_model) return model_data['uuid'] async def get_embedding_model(self, model_uuid: str) -> dict | None: + """Get a single embedding model with provider info""" result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.EmbeddingModel).where( persistence_model.EmbeddingModel.uuid == model_uuid ) ) - model = result.first() - if model is None: return None - return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + provider_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == model.provider_uuid + ) + ) + provider = provider_result.first() + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + model_dict['provider'] = _parse_provider_api_keys(provider_dict) + + return model_dict async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + """Update an existing embedding model""" if 'uuid' in model_data: del model_data['uuid'] + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_model.EmbeddingModel) .where(persistence_model.EmbeddingModel.uuid == model_uuid) @@ -170,21 +296,20 @@ class EmbeddingModelsService: ) await self.ap.model_mgr.remove_embedding_model(model_uuid) - embedding_model = await self.get_embedding_model(model_uuid) - await self.ap.model_mgr.load_embedding_model(embedding_model) async def delete_embedding_model(self, model_uuid: str) -> None: + """Delete an embedding model""" await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_model.EmbeddingModel).where( persistence_model.EmbeddingModel.uuid == model_uuid ) ) - await self.ap.model_mgr.remove_embedding_model(model_uuid) async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + """Test an embedding model""" runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None if model_uuid != '_': @@ -192,10 +317,8 @@ class EmbeddingModelsService: if model.model_entity.uuid == model_uuid: runtime_embedding_model = model break - if runtime_embedding_model is None: raise Exception('model not found') - else: runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) diff --git a/src/langbot/pkg/api/http/service/provider.py b/src/langbot/pkg/api/http/service/provider.py new file mode 100644 index 00000000..eb99c092 --- /dev/null +++ b/src/langbot/pkg/api/http/service/provider.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import uuid + +import sqlalchemy + +from ....core import app +from ....entity.persistence import model as persistence_model + + +class ModelProviderService: + """Service for managing model providers""" + + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_providers(self) -> list[dict]: + """Get all providers""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider)) + providers = result.all() + providers_list = [] + for p in providers: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, p) + # Parse api_keys if it's a JSON string + if isinstance(provider_dict.get('api_keys'), str): + import json + + try: + provider_dict['api_keys'] = json.loads(provider_dict['api_keys']) + except Exception: + provider_dict['api_keys'] = [] + providers_list.append(provider_dict) + return providers_list + + async def get_provider(self, provider_uuid: str) -> dict | None: + """Get a single provider by UUID""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == provider_uuid + ) + ) + provider = result.first() + if provider is None: + return None + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + # Parse api_keys if it's a JSON string + if isinstance(provider_dict.get('api_keys'), str): + import json + + try: + provider_dict['api_keys'] = json.loads(provider_dict['api_keys']) + except Exception: + provider_dict['api_keys'] = [] + return provider_dict + + async def create_provider(self, provider_data: dict) -> str: + """Create a new provider""" + provider_data['uuid'] = str(uuid.uuid4()) + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data) + ) + return provider_data['uuid'] + + async def update_provider(self, provider_uuid: str, provider_data: dict) -> None: + """Update an existing provider""" + if 'uuid' in provider_data: + del provider_data['uuid'] + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.ModelProvider) + .where(persistence_model.ModelProvider.uuid == provider_uuid) + .values(**provider_data) + ) + # Reload all models using this provider + await self.ap.model_mgr.load_models_from_db() + + async def delete_provider(self, provider_uuid: str) -> None: + """Delete a provider (only if no models reference it)""" + # Check if any models use this provider + llm_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.LLMModel).where( + persistence_model.LLMModel.provider_uuid == provider_uuid + ) + ) + if llm_result.first() is not None: + raise ValueError('Cannot delete provider: LLM models still reference it') + + embedding_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.provider_uuid == provider_uuid + ) + ) + if embedding_result.first() is not None: + raise ValueError('Cannot delete provider: Embedding models still reference it') + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == provider_uuid + ) + ) + + async def get_provider_model_counts(self, provider_uuid: str) -> dict: + """Get count of models using this provider""" + llm_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(sqlalchemy.func.count()) + .select_from(persistence_model.LLMModel) + .where(persistence_model.LLMModel.provider_uuid == provider_uuid) + ) + llm_count = llm_result.scalar() or 0 + + embedding_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(sqlalchemy.func.count()) + .select_from(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.provider_uuid == provider_uuid) + ) + embedding_count = embedding_result.scalar() or 0 + + return {'llm_count': llm_count, 'embedding_count': embedding_count} + + async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str: + """Find existing provider or create new one""" + # Try to find existing provider with same config + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.requester == requester, + persistence_model.ModelProvider.base_url == base_url, + ) + ) + for provider in result.all(): + if sorted(provider.api_keys or []) == sorted(api_keys or []): + return provider.uuid + + # Create new provider + provider_name = requester + if base_url: + try: + from urllib.parse import urlparse + + parsed = urlparse(base_url) + provider_name = parsed.netloc or requester + except Exception: + pass + + return await self.create_provider( + { + 'name': provider_name, + 'requester': requester, + 'base_url': base_url, + 'api_keys': api_keys or [], + } + ) diff --git a/src/langbot/pkg/api/http/service/space_models.py b/src/langbot/pkg/api/http/service/space_models.py deleted file mode 100644 index 634c7e39..00000000 --- a/src/langbot/pkg/api/http/service/space_models.py +++ /dev/null @@ -1,247 +0,0 @@ -from __future__ import annotations - -import typing -import uuid as uuid_lib -import aiohttp -import sqlalchemy - -from ....core import app -from ....entity.persistence import model as persistence_model -from ....entity.persistence import user as persistence_user - - -DEFAULT_SPACE_URL = 'http://localhost:8383' - -# Space's base URL for model API requests (used for requester_config) -SPACE_API_BASE_URL = 'http://localhost:8383' - - -class SpaceModelsService: - """Service for syncing models from Space MaaS""" - - ap: app.Application - - def __init__(self, ap: app.Application) -> None: - self.ap = ap - - async def get_space_user_info(self, user_email: str) -> persistence_user.User | None: - """Get Space user info for sync operations""" - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_user.User).where(persistence_user.User.user == user_email) - ) - result_list = result.all() - return result_list[0] if result_list else None - - async def fetch_space_models(self, space_url: str = DEFAULT_SPACE_URL) -> typing.Dict: - """Fetch available models from Space API""" - async with aiohttp.ClientSession() as session: - async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response: - if response.status != 200: - raise ValueError(f'Failed to fetch models from Space: {await response.text()}') - data = await response.json() - if data.get('code') != 0: - raise ValueError(f'Failed to fetch models from Space: {data.get("msg")}') - return data.get('data', {}) - - async def sync_models_from_space( - self, user_email: str, space_url: str = DEFAULT_SPACE_URL - ) -> typing.Dict[str, typing.Any]: - """ - Sync models from Space to local database. - Returns statistics about the sync operation. - """ - # Get user info for API key - user_obj = await self.get_space_user_info(user_email) - if user_obj is None: - raise ValueError('User not found') - - if user_obj.account_type != 'space': - raise ValueError('User is not a Space account') - - if not user_obj.space_api_key: - raise ValueError('User does not have a Space API key configured') - - # Fetch models from Space - models_data = await self.fetch_space_models(space_url) - space_models = models_data.get('models', []) - - # Get existing Space models in local database - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space') - ) - existing_space_models = {m.space_model_id: m for m in result.all()} - - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.EmbeddingModel).where( - persistence_model.EmbeddingModel.source == 'space' - ) - ) - existing_space_embedding_models = {m.space_model_id: m for m in result.all()} - - stats = {'created_llm': 0, 'updated_llm': 0, 'created_embedding': 0, 'updated_embedding': 0, 'skipped': 0} - - for model in space_models: - model_id = model.get('model_id') - category = model.get('category', '') - - if not model_id: - stats['skipped'] += 1 - continue - - if category == 'embedding': - # Handle embedding model - await self._sync_embedding_model(model, user_obj.space_api_key, existing_space_embedding_models, stats) - else: - # Handle LLM model (chat, completion, etc.) - await self._sync_llm_model(model, user_obj.space_api_key, existing_space_models, stats) - - return stats - - async def _sync_llm_model( - self, - model: typing.Dict, - api_key: str, - existing_models: typing.Dict[str, persistence_model.LLMModel], - stats: typing.Dict, - ) -> None: - """Sync a single LLM model from Space""" - model_id = model.get('model_id') - display_name = model.get('display_name', {}) - name = display_name.get('zh_Hans', display_name.get('en_US', model_id)) - description_obj = model.get('description', {}) - description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else '' - - # Infer abilities from model capabilities - abilities = [] - supported_endpoints = model.get('supported_endpoints', []) - if 'vision' in str(supported_endpoints).lower() or 'vision' in model_id.lower(): - abilities.append('vision') - if 'function' in str(supported_endpoints).lower() or 'tool' in str(supported_endpoints).lower(): - abilities.append('function_call') - - model_data = { - 'name': name, - 'description': description[:255] if description else 'Model from Space MaaS', - 'requester': 'openai-chat-completions', # Space uses OpenAI-compatible API - 'requester_config': { - 'base-url': SPACE_API_BASE_URL, - 'args': {}, - 'timeout': 120, - }, - 'api_keys': [api_key], - 'abilities': abilities, - 'extra_args': {'model': model_id}, - 'source': 'space', - 'space_model_id': model_id, - } - - if model_id in existing_models: - # Update existing model - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_model.LLMModel) - .where(persistence_model.LLMModel.space_model_id == model_id) - .values(**model_data) - ) - stats['updated_llm'] += 1 - else: - # Create new model - model_data['uuid'] = str(uuid_lib.uuid4()) - await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_model.LLMModel).values(**model_data) - ) - stats['created_llm'] += 1 - - async def _sync_embedding_model( - self, - model: typing.Dict, - api_key: str, - existing_models: typing.Dict[str, persistence_model.EmbeddingModel], - stats: typing.Dict, - ) -> None: - """Sync a single embedding model from Space""" - model_id = model.get('model_id') - display_name = model.get('display_name', {}) - name = display_name.get('zh_Hans', display_name.get('en_US', model_id)) - description_obj = model.get('description', {}) - description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else '' - - model_data = { - 'name': name, - 'description': description[:255] if description else 'Embedding model from Space MaaS', - 'requester': 'openai-embedding', # Space uses OpenAI-compatible API - 'requester_config': { - 'base-url': SPACE_API_BASE_URL, - 'args': {}, - 'timeout': 120, - }, - 'api_keys': [api_key], - 'extra_args': {'model': model_id}, - 'source': 'space', - 'space_model_id': model_id, - } - - if model_id in existing_models: - # Update existing model - await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_model.EmbeddingModel) - .where(persistence_model.EmbeddingModel.space_model_id == model_id) - .values(**model_data) - ) - stats['updated_embedding'] += 1 - else: - # Create new model - model_data['uuid'] = str(uuid_lib.uuid4()) - await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) - ) - stats['created_embedding'] += 1 - - async def get_space_models(self) -> typing.Dict[str, typing.List]: - """Get all synced Space models""" - llm_result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space') - ) - embedding_result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.EmbeddingModel).where( - persistence_model.EmbeddingModel.source == 'space' - ) - ) - - return { - 'llm_models': [ - self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in llm_result.all() - ], - 'embedding_models': [ - self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) - for m in embedding_result.all() - ], - } - - async def delete_space_models(self) -> typing.Dict[str, int]: - """Delete all synced Space models""" - # Remove from model manager first - llm_result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space') - ) - for model in llm_result.all(): - await self.ap.model_mgr.remove_llm_model(model.uuid) - - embedding_result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.EmbeddingModel).where( - persistence_model.EmbeddingModel.source == 'space' - ) - ) - for model in embedding_result.all(): - await self.ap.model_mgr.remove_embedding_model(model.uuid) - - # Delete from database - llm_delete = await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space') - ) - embedding_delete = await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_model.EmbeddingModel).where( - persistence_model.EmbeddingModel.source == 'space' - ) - ) - - return {'deleted_llm': llm_delete.rowcount, 'deleted_embedding': embedding_delete.rowcount} diff --git a/src/langbot/pkg/core/app.py b/src/langbot/pkg/core/app.py index 4b8d22a3..38f93daa 100644 --- a/src/langbot/pkg/core/app.py +++ b/src/langbot/pkg/core/app.py @@ -20,6 +20,7 @@ from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller from ..api.http.service import user as user_service from ..api.http.service import model as model_service +from ..api.http.service import provider as provider_service from ..api.http.service import pipeline as pipeline_service from ..api.http.service import bot as bot_service from ..api.http.service import knowledge as knowledge_service @@ -27,7 +28,6 @@ from ..api.http.service import mcp as mcp_service from ..api.http.service import apikey as apikey_service from ..api.http.service import webhook as webhook_service from ..api.http.service import external_kb as external_kb_service -from ..api.http.service import space_models as space_models_service from ..discover import engine as discover_engine from ..storage import mgr as storagemgr from ..utils import logcache @@ -119,6 +119,8 @@ class Application: embedding_models_service: model_service.EmbeddingModelsService = None + provider_service: provider_service.ModelProviderService = None + pipeline_service: pipeline_service.PipelineService = None bot_service: bot_service.BotService = None @@ -133,8 +135,6 @@ class Application: webhook_service: webhook_service.WebhookService = None - space_models_service: space_models_service.SpaceModelsService = None - def __init__(self): pass diff --git a/src/langbot/pkg/core/stages/build_app.py b/src/langbot/pkg/core/stages/build_app.py index 94a4c293..b2a054ed 100644 --- a/src/langbot/pkg/core/stages/build_app.py +++ b/src/langbot/pkg/core/stages/build_app.py @@ -17,6 +17,7 @@ from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller from ...api.http.service import user as user_service from ...api.http.service import model as model_service +from ...api.http.service import provider as provider_service from ...api.http.service import pipeline as pipeline_service from ...api.http.service import bot as bot_service from ...api.http.service import knowledge as knowledge_service @@ -24,7 +25,6 @@ from ...api.http.service import mcp as mcp_service from ...api.http.service import apikey as apikey_service from ...api.http.service import webhook as webhook_service from ...api.http.service import external_kb as external_kb_service -from ...api.http.service import space_models as space_models_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache @@ -115,6 +115,9 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst + provider_service_inst = provider_service.ModelProviderService(ap) + ap.provider_service = provider_service_inst + pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst @@ -136,9 +139,6 @@ class BuildAppStage(stage.BootingStage): webhook_service_inst = webhook_service.WebhookService(ap) ap.webhook_service = webhook_service_inst - space_models_service_inst = space_models_service.SpaceModelsService(ap) - ap.space_models_service = space_models_service_inst - async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None: await asyncio.sleep(3) await plugin_connector_inst.initialize() diff --git a/src/langbot/pkg/entity/persistence/model.py b/src/langbot/pkg/entity/persistence/model.py index 0dea51ff..e4459585 100644 --- a/src/langbot/pkg/entity/persistence/model.py +++ b/src/langbot/pkg/entity/persistence/model.py @@ -3,6 +3,25 @@ import sqlalchemy from .base import Base +class ModelProvider(Base): + """Model provider""" + + __tablename__ = 'model_providers' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + base_url = sqlalchemy.Column(sqlalchemy.String(512), nullable=False) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) + + class LLMModel(Base): """LLM model""" @@ -10,16 +29,9 @@ class LLMModel(Base): uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - # Source tracking for Space integration: 'local' or 'space' - source = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='local') - # Space model ID for synced models (used to track and update synced models) - space_model_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, @@ -30,21 +42,14 @@ class LLMModel(Base): class EmbeddingModel(Base): - """Embedding 模型""" + """Embedding model""" __tablename__ = 'embedding_models' uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - # Source tracking for Space integration: 'local' or 'space' - source = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='local') - # Space model ID for synced models (used to track and update synced models) - space_model_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, diff --git a/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py b/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py new file mode 100644 index 00000000..88438409 --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py @@ -0,0 +1,286 @@ +import uuid as uuid_lib + +import sqlalchemy +from .. import migration + + +@migration.migration_class(16) +class DBMigrateModelProviderRefactor(migration.DBMigration): + """Refactor model structure: create providers from existing models and update references""" + + async def upgrade(self): + """Upgrade""" + # Step 1: Create model_providers table if not exists + await self._create_providers_table() + + # Step 2: Migrate existing models to use providers + await self._migrate_llm_models() + await self._migrate_embedding_models() + + # Step 3: Remove deprecated columns + await self._cleanup_columns() + + async def _create_providers_table(self): + """Create model_providers table""" + if self.ap.persistence_mgr.db.name == 'postgresql': + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + CREATE TABLE IF NOT EXISTS model_providers ( + uuid VARCHAR(255) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + requester VARCHAR(255) NOT NULL, + base_url VARCHAR(512) NOT NULL, + api_keys JSONB NOT NULL DEFAULT '[]', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """) + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + CREATE TABLE IF NOT EXISTS model_providers ( + uuid VARCHAR(255) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + requester VARCHAR(255) NOT NULL, + base_url VARCHAR(512) NOT NULL, + api_keys JSON NOT NULL DEFAULT '[]', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """) + ) + + async def _migrate_llm_models(self): + """Migrate LLM models to use providers""" + llm_columns = await self._get_columns('llm_models') + + # Add provider_uuid column if not exists + if 'provider_uuid' not in llm_columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN provider_uuid VARCHAR(255)') + ) + + # Only migrate if old columns exist + if 'requester' not in llm_columns: + return + + # Get all LLM models with old structure + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM llm_models') + ) + models = result.fetchall() + + # Create providers and update models + provider_cache = {} # (requester, base_url, api_keys_str) -> provider_uuid + + for model in models: + model_uuid, model_name, requester, requester_config, api_keys = model + + # Extract base_url from requester_config + base_url = '' + if requester_config: + if isinstance(requester_config, str): + import json + + requester_config = json.loads(requester_config) + base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '') + + # Parse api_keys if it's a string + if isinstance(api_keys, str): + import json + + try: + api_keys = json.loads(api_keys) + except Exception: + api_keys = [] + if not api_keys: + api_keys = [] + + # Create cache key + api_keys_str = str(sorted(api_keys)) if api_keys else '[]' + cache_key = (requester, base_url, api_keys_str) + + if cache_key in provider_cache: + provider_uuid = provider_cache[cache_key] + else: + # Create new provider + provider_uuid = str(uuid_lib.uuid4()) + provider_name = f'{requester}' + if base_url: + # Extract domain for name + try: + from urllib.parse import urlparse + + parsed = urlparse(base_url) + provider_name = parsed.netloc or requester + except Exception: + pass + + import json + + api_keys_json = json.dumps(api_keys) if api_keys else '[]' + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + INSERT INTO model_providers (uuid, name, requester, base_url, api_keys) + VALUES (:uuid, :name, :requester, :base_url, :api_keys) + """), + { + 'uuid': provider_uuid, + 'name': provider_name, + 'requester': requester, + 'base_url': base_url, + 'api_keys': api_keys_json, + }, + ) + provider_cache[cache_key] = provider_uuid + + # Update model with provider_uuid + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('UPDATE llm_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'), + {'provider_uuid': provider_uuid, 'uuid': model_uuid}, + ) + + async def _migrate_embedding_models(self): + """Migrate embedding models to use providers""" + embedding_columns = await self._get_columns('embedding_models') + + # Add provider_uuid column if not exists + if 'provider_uuid' not in embedding_columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN provider_uuid VARCHAR(255)') + ) + + # Only migrate if old columns exist + if 'requester' not in embedding_columns: + return + + # Get all embedding models with old structure + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM embedding_models') + ) + models = result.fetchall() + + # Get existing providers + provider_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('SELECT uuid, requester, base_url, api_keys FROM model_providers') + ) + existing_providers = provider_result.fetchall() + + provider_cache = {} + for p in existing_providers: + p_uuid, p_requester, p_base_url, p_api_keys = p + api_keys_str = str(sorted(p_api_keys)) if p_api_keys else '[]' + provider_cache[(p_requester, p_base_url, api_keys_str)] = p_uuid + + for model in models: + model_uuid, model_name, requester, requester_config, api_keys = model + + base_url = '' + if requester_config: + if isinstance(requester_config, str): + import json + + requester_config = json.loads(requester_config) + base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '') + + # Parse api_keys if it's a string + if isinstance(api_keys, str): + import json + + try: + api_keys = json.loads(api_keys) + except Exception: + api_keys = [] + if not api_keys: + api_keys = [] + + api_keys_str = str(sorted(api_keys)) if api_keys else '[]' + cache_key = (requester, base_url, api_keys_str) + + if cache_key in provider_cache: + provider_uuid = provider_cache[cache_key] + else: + provider_uuid = str(uuid_lib.uuid4()) + provider_name = f'{requester}' + if base_url: + try: + from urllib.parse import urlparse + + parsed = urlparse(base_url) + provider_name = parsed.netloc or requester + except Exception: + pass + + import json + + api_keys_json = json.dumps(api_keys) if api_keys else '[]' + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + INSERT INTO model_providers (uuid, name, requester, base_url, api_keys) + VALUES (:uuid, :name, :requester, :base_url, :api_keys) + """), + { + 'uuid': provider_uuid, + 'name': provider_name, + 'requester': requester, + 'base_url': base_url, + 'api_keys': api_keys_json, + }, + ) + provider_cache[cache_key] = provider_uuid + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('UPDATE embedding_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'), + {'provider_uuid': provider_uuid, 'uuid': model_uuid}, + ) + + async def _cleanup_columns(self): + """Remove deprecated columns from model tables""" + # SQLite doesn't support DROP COLUMN easily, so we skip for SQLite + if self.ap.persistence_mgr.db.name != 'postgresql': + return + + llm_columns = await self._get_columns('llm_models') + deprecated_llm_cols = ['requester', 'requester_config', 'api_keys', 'description', 'source', 'space_model_id'] + for col in deprecated_llm_cols: + if col in llm_columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN IF EXISTS {col}') + ) + + embedding_columns = await self._get_columns('embedding_models') + deprecated_embedding_cols = [ + 'requester', + 'requester_config', + 'api_keys', + 'description', + 'source', + 'space_model_id', + ] + for col in deprecated_embedding_cols: + if col in embedding_columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN IF EXISTS {col}') + ) + + async def _get_columns(self, table_name: str) -> list: + """Get column names for a table""" + if self.ap.persistence_mgr.db.name == 'postgresql': + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text( + f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}';" + ) + ) + all_result = result.fetchall() + return [row[0] for row in all_result] + else: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});')) + all_result = result.fetchall() + return [row[1] for row in all_result] + + async def downgrade(self): + """Downgrade""" + pass diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index f0bec0a5..3c369638 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -10,11 +10,9 @@ from . import token from ...entity.persistence import model as persistence_model from ...entity.errors import provider as provider_errors -FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list' - class ModelManager: - """模型管理器""" + """Model manager""" ap: app.Application @@ -24,7 +22,7 @@ class ModelManager: requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] def __init__(self, ap: app.Application): self.ap = ap @@ -36,7 +34,6 @@ class ModelManager: async def initialize(self): self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') - # forge requester class dict requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -46,29 +43,45 @@ class ModelManager: await self.load_models_from_db() async def load_models_from_db(self): - """从数据库加载模型""" + """Load models from database""" self.ap.logger.info('Loading models from db...') self.llm_models = [] self.embedding_models = [] - # llm models + # Load all providers first + providers_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider) + ) + providers = {p.uuid: p for p in providers_result.all()} + + # Load LLM models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) llm_models = result.all() for llm_model in llm_models: try: - await self.load_llm_model(llm_model) + provider = providers.get(llm_model.provider_uuid) + if provider is None: + self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}') + continue + await self.load_llm_model_with_provider(llm_model, provider) except provider_errors.RequesterNotFoundError as e: self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}') except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') - # embedding models + # Load embedding models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) embedding_models = result.all() for embedding_model in embedding_models: try: - await self.load_embedding_model(embedding_model) + provider = providers.get(embedding_model.provider_uuid) + if provider is None: + self.ap.logger.warning( + f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}' + ) + continue + await self.load_embedding_model_with_provider(embedding_model, provider) except provider_errors.RequesterNotFoundError as e: self.ap.logger.warning( f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}' @@ -78,27 +91,33 @@ class ModelManager: async def init_runtime_llm_model( self, - model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, + model_info: dict, ): - """初始化运行时 LLM 模型""" - if isinstance(model_info, sqlalchemy.Row): - model_info = persistence_model.LLMModel(**model_info._mapping) - elif isinstance(model_info, dict): - model_info = persistence_model.LLMModel(**model_info) + """Initialize runtime LLM model from dict (for testing)""" + provider_info = model_info.get('provider', {}) + requester_name = provider_info.get('requester', '') + base_url = provider_info.get('base_url', '') + api_keys = provider_info.get('api_keys', []) - if model_info.requester not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(model_info.requester) - - requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + if requester_name not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(requester_name) + requester_cfg = {'base_url': base_url} + requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg) await requester_inst.initialize() + # Create a temporary model entity + model_entity = persistence_model.LLMModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid='', + abilities=model_info.get('abilities', []), + extra_args=model_info.get('extra_args', {}), + ) + runtime_llm_model = requester.RuntimeLLMModel( - model_entity=model_info, - token_mgr=token.TokenManager( - name=model_info.uuid, - tokens=model_info.api_keys, - ), + model_entity=model_entity, + token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys), requester=requester_inst, ) @@ -106,78 +125,165 @@ class ModelManager: async def init_runtime_embedding_model( self, - model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + model_info: dict, ): - """初始化运行时 Embedding 模型""" - if isinstance(model_info, sqlalchemy.Row): - model_info = persistence_model.EmbeddingModel(**model_info._mapping) - elif isinstance(model_info, dict): - model_info = persistence_model.EmbeddingModel(**model_info) + """Initialize runtime embedding model from dict (for testing)""" + provider_info = model_info.get('provider', {}) + requester_name = provider_info.get('requester', '') + base_url = provider_info.get('base_url', '') + api_keys = provider_info.get('api_keys', []) - if model_info.requester not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(model_info.requester) - - requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + if requester_name not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(requester_name) + requester_cfg = {'base_url': base_url} + requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg) await requester_inst.initialize() + model_entity = persistence_model.EmbeddingModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid='', + extra_args=model_info.get('extra_args', {}), + ) + runtime_embedding_model = requester.RuntimeEmbeddingModel( - model_entity=model_info, - token_mgr=token.TokenManager( - name=model_info.uuid, - tokens=model_info.api_keys, - ), + model_entity=model_entity, + token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys), requester=requester_inst, ) return runtime_embedding_model - async def load_llm_model( + async def load_llm_model_with_provider( self, - model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, + model_info: persistence_model.LLMModel | sqlalchemy.Row, + provider: persistence_model.ModelProvider | sqlalchemy.Row, ): - """加载 LLM 模型""" - runtime_llm_model = await self.init_runtime_llm_model(model_info) + """Load LLM model with provider info""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.LLMModel(**model_info._mapping) + if isinstance(provider, sqlalchemy.Row): + provider = persistence_model.ModelProvider(**provider._mapping) + + if provider.requester not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(provider.requester) + + requester_cfg = {'base_url': provider.base_url} + requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg) + await requester_inst.initialize() + + runtime_llm_model = requester.RuntimeLLMModel( + model_entity=model_info, + token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []), + requester=requester_inst, + ) + self.llm_models.append(runtime_llm_model) - async def load_embedding_model( + async def load_embedding_model_with_provider( self, - model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row, + provider: persistence_model.ModelProvider | sqlalchemy.Row, ): - """加载 Embedding 模型""" - runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + """Load embedding model with provider info""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + if isinstance(provider, sqlalchemy.Row): + provider = persistence_model.ModelProvider(**provider._mapping) + + if provider.requester not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(provider.requester) + + requester_cfg = {'base_url': provider.base_url} + requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg) + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []), + requester=requester_inst, + ) + self.embedding_models.append(runtime_embedding_model) + async def load_llm_model(self, model_info: dict): + """Load LLM model from dict (with provider info)""" + provider_info = model_info.get('provider', {}) + if not provider_info: + raise ValueError('Provider info is required') + + model_entity = persistence_model.LLMModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid=model_info.get('provider_uuid', ''), + abilities=model_info.get('abilities', []), + extra_args=model_info.get('extra_args', {}), + ) + + provider_entity = persistence_model.ModelProvider( + uuid=provider_info.get('uuid', ''), + name=provider_info.get('name', ''), + requester=provider_info.get('requester', ''), + base_url=provider_info.get('base_url', ''), + api_keys=provider_info.get('api_keys', []), + ) + + await self.load_llm_model_with_provider(model_entity, provider_entity) + + async def load_embedding_model(self, model_info: dict): + """Load embedding model from dict (with provider info)""" + provider_info = model_info.get('provider', {}) + if not provider_info: + raise ValueError('Provider info is required') + + model_entity = persistence_model.EmbeddingModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid=model_info.get('provider_uuid', ''), + extra_args=model_info.get('extra_args', {}), + ) + + provider_entity = persistence_model.ModelProvider( + uuid=provider_info.get('uuid', ''), + name=provider_info.get('name', ''), + requester=provider_info.get('requester', ''), + base_url=provider_info.get('base_url', ''), + api_keys=provider_info.get('api_keys', []), + ) + + await self.load_embedding_model_with_provider(model_entity, provider_entity) + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: - """通过uuid获取 LLM 模型""" + """Get LLM model by uuid""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model raise ValueError(f'LLM model {uuid} not found') async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: - """通过uuid获取 Embedding 模型""" + """Get embedding model by uuid""" for model in self.embedding_models: if model.model_entity.uuid == uuid: return model raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除 LLM 模型""" + """Remove LLM model""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return async def remove_embedding_model(self, model_uuid: str): - """移除 Embedding 模型""" + """Remove embedding model""" for model in self.embedding_models: if model.model_entity.uuid == model_uuid: self.embedding_models.remove(model) return def get_available_requesters_info(self, model_type: str) -> list[dict]: - """获取所有可用的请求器""" + """Get all available requesters""" if model_type != '': return [ component.to_plain_dict() @@ -188,14 +294,14 @@ class ModelManager: return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: - """通过名称获取请求器信息""" + """Get requester info by name""" for component in self.requester_components: if component.metadata.name == name: return component.to_plain_dict() return None def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None: - """通过名称获取请求器清单""" + """Get requester manifest by name""" for component in self.requester_components: if component.metadata.name == name: return component diff --git a/src/langbot/pkg/utils/constants.py b/src/langbot/pkg/utils/constants.py index 66541ba8..a82d578c 100644 --- a/src/langbot/pkg/utils/constants.py +++ b/src/langbot/pkg/utils/constants.py @@ -2,7 +2,7 @@ import langbot semantic_version = f'v{langbot.__version__}' -required_database_version = 15 +required_database_version = 16 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx index feda5e36..690d8de1 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx @@ -254,118 +254,36 @@ export default function DynamicFormItemComponent({ ); case DynamicFormItemType.LLM_MODEL_SELECTOR: + // Group models by provider + const groupedModels = llmModels.reduce( + (acc, model) => { + const providerName = + model.provider?.name || model.provider?.requester || 'Unknown'; + if (!acc[providerName]) acc[providerName] = []; + acc[providerName].push(model); + return acc; + }, + {} as Record, + ); + return ( ); diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index b454a8e3..2e5079a3 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -5,19 +5,19 @@ import { Plus, MessageSquareText, Cpu, - Info, - RefreshCw, - ChevronLeft, - Cloud, - HardDrive, - Lock, + ChevronDown, + ChevronRight, + Trash2, + Settings, + Sparkles, + LogIn, } from 'lucide-react'; -import { LLMCardVO } from './component/llm-card/LLMCardVO'; -import LLMCard from './component/llm-card/LLMCard'; -import LLMForm from './component/llm-form/LLMForm'; import { httpClient } from '@/app/infra/http/HttpClient'; -import { LLMModel } from '@/app/infra/entities/api'; -import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; +import { + LLMModel, + EmbeddingModel, + ModelProvider, +} from '@/app/infra/entities/api'; import { Dialog, DialogContent, @@ -25,68 +25,67 @@ import { DialogTitle, } from '@/components/ui/dialog'; import { Button } from '@/components/ui/button'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from '@/components/ui/collapsible'; import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { extractI18nObject } from '@/i18n/I18nProvider'; -import { EmbeddingCardVO } from './component/embedding-card/EmbeddingCardVO'; -import EmbeddingCard from './component/embedding-card/EmbeddingCard'; -import EmbeddingForm from './component/embedding-form/EmbeddingForm'; -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from '@/components/ui/card'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Badge } from '@/components/ui/badge'; +import LLMForm from './component/llm-form/LLMForm'; +import EmbeddingForm from './component/embedding-form/EmbeddingForm'; +import ProviderForm from './component/provider-form/ProviderForm'; interface ModelsDialogProps { open: boolean; onOpenChange: (open: boolean) => void; } -type ViewMode = 'providers' | 'space' | 'local'; +const LANGBOT_MODELS_PROVIDER_NAME = 'LangBot Models'; export default function ModelsDialog({ open, onOpenChange, }: ModelsDialogProps) { const { t } = useTranslation(); - const [viewMode, setViewMode] = useState('providers'); - const [activeTab, setActiveTab] = useState('llm'); - // User account type + const [providers, setProviders] = useState([]); const [accountType, setAccountType] = useState<'local' | 'space'>('local'); + const [spaceBalance] = useState(null); - // Local models - const [localLLMList, setLocalLLMList] = useState([]); - const [localEmbeddingList, setLocalEmbeddingList] = useState< - EmbeddingCardVO[] - >([]); - - // Space models - const [spaceLLMList, setSpaceLLMList] = useState([]); - const [spaceEmbeddingList, setSpaceEmbeddingList] = useState< - EmbeddingCardVO[] - >([]); - - // Sync state - const [isSyncing, setIsSyncing] = useState(false); + // Expanded providers and their models + const [expandedProviders, setExpandedProviders] = useState>( + new Set(), + ); + const [providerModels, setProviderModels] = useState< + Record + >({}); + const [loadingProviders, setLoadingProviders] = useState>( + new Set(), + ); // Form modals - const [modalOpen, setModalOpen] = useState(false); - const [isEditForm, setIsEditForm] = useState(false); - const [nowSelectedLLM, setNowSelectedLLM] = useState(null); - const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); - const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); - const [nowSelectedEmbedding, setNowSelectedEmbedding] = - useState(null); + const [llmFormOpen, setLLMFormOpen] = useState(false); + const [embeddingFormOpen, setEmbeddingFormOpen] = useState(false); + const [providerFormOpen, setProviderFormOpen] = useState(false); + const [editingLLMId, setEditingLLMId] = useState(null); + const [editingEmbeddingId, setEditingEmbeddingId] = useState( + null, + ); + const [editingProviderId, setEditingProviderId] = useState( + null, + ); - // Requester name lists for display - const [llmRequesterNameList, setLLMRequesterNameList] = useState< - { label: string; value: string }[] - >([]); - const [embeddingRequesterNameList, setEmbeddingRequesterNameList] = useState< + const [requesterNameList, setRequesterNameList] = useState< { label: string; value: string }[] >([]); @@ -94,7 +93,7 @@ export default function ModelsDialog({ if (open) { loadUserInfo(); loadRequesterLists(); - loadAllModels(); + loadProviders(); } }, [open]); @@ -103,7 +102,6 @@ export default function ModelsDialog({ const userInfo = await httpClient.getUserInfo(); setAccountType(userInfo.account_type); } catch { - // Default to local if user info cannot be fetched setAccountType('local'); } } @@ -111,347 +109,406 @@ export default function ModelsDialog({ async function loadRequesterLists() { try { const llmRequesters = await httpClient.getProviderRequesters('llm'); - setLLMRequesterNameList( + setRequesterNameList( llmRequesters.requesters.map((item) => ({ label: extractI18nObject(item.label), value: item.name, })), ); - - const embeddingRequesters = - await httpClient.getProviderRequesters('text-embedding'); - setEmbeddingRequesterNameList( - embeddingRequesters.requesters.map((item) => ({ - label: extractI18nObject(item.label), - value: item.name, - })), - ); } catch (err) { console.error('Failed to load requester lists', err); } } - async function loadAllModels() { - await Promise.all([loadLLMModels(), loadEmbeddingModels()]); - } - - async function loadLLMModels() { + async function loadProviders() { try { - const resp = await httpClient.getProviderLLMModels(); - const localModels: LLMCardVO[] = []; - const spaceModels: LLMCardVO[] = []; - - resp.models.forEach((model: LLMModel & { source?: string }) => { - const cardVO = new LLMCardVO({ - id: model.uuid, - iconURL: httpClient.getProviderRequesterIconURL(model.requester), - name: model.name, - providerLabel: - llmRequesterNameList.find((item) => item.value === model.requester) - ?.label || model.requester.substring(0, 10), - baseURL: model.requester_config?.base_url, - abilities: model.abilities || [], - }); - - if (model.source === 'space') { - spaceModels.push(cardVO); - } else { - localModels.push(cardVO); - } - }); - - setLocalLLMList(localModels); - setSpaceLLMList(spaceModels); + const resp = await httpClient.getModelProviders(); + setProviders(resp.providers); } catch (err) { - console.error('Failed to load LLM models', err); - toast.error(t('models.getModelListError') + (err as Error).message); + console.error('Failed to load providers', err); + toast.error(t('models.loadError')); } } - async function loadEmbeddingModels() { + async function loadProviderModels(providerUuid: string) { + if (loadingProviders.has(providerUuid)) return; + + setLoadingProviders((prev) => new Set(prev).add(providerUuid)); try { - const resp = await httpClient.getProviderEmbeddingModels(); - const localModels: EmbeddingCardVO[] = []; - const spaceModels: EmbeddingCardVO[] = []; - - resp.models.forEach( - (model: { - uuid: string; - requester: string; - name: string; - requester_config?: { base_url?: string }; - source?: string; - }) => { - const cardVO = new EmbeddingCardVO({ - id: model.uuid, - iconURL: httpClient.getProviderRequesterIconURL(model.requester), - name: model.name, - providerLabel: - embeddingRequesterNameList.find( - (item) => item.value === model.requester, - )?.label || model.requester.substring(0, 10), - baseURL: model.requester_config?.base_url || '', - }); - - if (model.source === 'space') { - spaceModels.push(cardVO); - } else { - localModels.push(cardVO); - } + const [llmResp, embeddingResp] = await Promise.all([ + httpClient.getProviderLLMModels(providerUuid), + httpClient.getProviderEmbeddingModels(providerUuid), + ]); + setProviderModels((prev) => ({ + ...prev, + [providerUuid]: { + llm: llmResp.models, + embedding: embeddingResp.models, }, - ); - - setLocalEmbeddingList(localModels); - setSpaceEmbeddingList(spaceModels); + })); } catch (err) { - console.error('Failed to load embedding models', err); - toast.error(t('embedding.getModelListError') + (err as Error).message); - } - } - - async function handleSyncSpaceModels() { - setIsSyncing(true); - try { - const stats = await httpClient.syncSpaceModels(); - toast.success( - t('models.syncSuccess', { - created: stats.created_llm + stats.created_embedding, - updated: stats.updated_llm + stats.updated_embedding, - }), - ); - await loadAllModels(); - } catch (err) { - toast.error(t('models.syncError') + (err as Error).message); + console.error('Failed to load models', err); } finally { - setIsSyncing(false); + setLoadingProviders((prev) => { + const next = new Set(prev); + next.delete(providerUuid); + return next; + }); } } - function selectLLM(cardVO: LLMCardVO, isSpaceModel: boolean) { - if (isSpaceModel) { - // Space models are read-only, just show info - toast.info(t('models.spaceModelReadOnly')); - return; + function toggleProvider(providerUuid: string) { + setExpandedProviders((prev) => { + const next = new Set(prev); + if (next.has(providerUuid)) { + next.delete(providerUuid); + } else { + next.add(providerUuid); + if (!providerModels[providerUuid]) { + loadProviderModels(providerUuid); + } + } + return next; + }); + } + + function handleCreateLLM() { + setEditingLLMId(null); + setLLMFormOpen(true); + } + + function handleCreateEmbedding() { + setEditingEmbeddingId(null); + setEmbeddingFormOpen(true); + } + + function handleEditLLM(modelId: string) { + setEditingLLMId(modelId); + setLLMFormOpen(true); + } + + function handleEditEmbedding(modelId: string) { + setEditingEmbeddingId(modelId); + setEmbeddingFormOpen(true); + } + + function handleEditProvider(providerId: string) { + setEditingProviderId(providerId); + setProviderFormOpen(true); + } + + async function handleDeleteProvider(providerId: string) { + try { + await httpClient.deleteModelProvider(providerId); + toast.success(t('models.providerDeleted')); + loadProviders(); + } catch (err) { + toast.error(t('models.providerDeleteError') + (err as Error).message); } - setIsEditForm(true); - setNowSelectedLLM(cardVO); - setModalOpen(true); } - function handleCreateModelClick() { - setIsEditForm(false); - setNowSelectedLLM(null); - setModalOpen(true); - } - - function selectEmbedding(cardVO: EmbeddingCardVO, isSpaceModel: boolean) { - if (isSpaceModel) { - toast.info(t('models.spaceModelReadOnly')); - return; + async function handleDeleteLLM(modelId: string, providerUuid: string) { + try { + await httpClient.deleteProviderLLMModel(modelId); + toast.success(t('models.deleteSuccess')); + loadProviderModels(providerUuid); + loadProviders(); // Refresh counts + } catch (err) { + toast.error(t('models.deleteError') + (err as Error).message); } - setIsEditEmbeddingForm(true); - setNowSelectedEmbedding(cardVO); - setEmbeddingModalOpen(true); } - function handleCreateEmbeddingModelClick() { - setIsEditEmbeddingForm(false); - setNowSelectedEmbedding(null); - setEmbeddingModalOpen(true); + async function handleDeleteEmbedding(modelId: string, providerUuid: string) { + try { + await httpClient.deleteProviderEmbeddingModel(modelId); + toast.success(t('models.deleteSuccess')); + loadProviderModels(providerUuid); + loadProviders(); + } catch (err) { + toast.error(t('models.deleteError') + (err as Error).message); + } } - function renderProviderCards() { - const isSpaceDisabled = accountType === 'local'; + function handleSpaceLogin() { + window.location.href = '/auth/space'; + } + function getRequesterLabel(requester: string) { return ( -
- {/* Space Provider Card */} - !isSpaceDisabled && setViewMode('space')} - > - -
- -
-
-
- Space - {isSpaceDisabled && ( - - )} -
- - {isSpaceDisabled - ? t('models.spaceDisabledForLocalAccount') - : t('models.spaceProviderDescription')} - -
-
- -
- {spaceLLMList.length} LLM - - {spaceEmbeddingList.length} Embedding - -
-
-
- - {/* Local Provider Card */} - setViewMode('local')} - > - -
- -
-
- {t('models.localProvider')} - - {t('models.localProviderDescription')} - -
-
- -
- {localLLMList.length} LLM - - {localEmbeddingList.length} Embedding - -
-
-
-
+ requesterNameList.find((r) => r.value === requester)?.label || requester ); } - function renderModelList( - llmList: LLMCardVO[], - embeddingList: EmbeddingCardVO[], - isSpaceModel: boolean = false, - ) { - return ( - -
- - - - {t('llm.llmModels')} - - - - {t('embedding.embeddingModels')} - - + function maskApiKey(key: string): string { + if (!key) return ''; + if (key.length <= 8) return '****'; + return `${key.slice(0, 4)}...${key.slice(-4)}`; + } -
- {isSpaceModel ? ( - - ) : ( - + )} + {isLangBotModels && accountType === 'space' && ( + + {t('models.balance')}: {spaceBalance ?? '--'} + + )} + {!isLangBotModels && ( + <> + + {canDelete && ( + + )} + + )} +
+
+ + {isExpanded ? ( + + ) : ( + + )} + + {isExpanded + ? t('models.collapseModels') + : t('models.expandModels')} + + + + + + {isLoading ? ( +

+ {t('common.loading')}... +

+ ) : models ? ( +
+ {models.llm.map((model) => ( +
handleEditLLM(model.uuid)} + > +
+ + {model.name} + + + {t('models.chat')} + + {model.abilities?.includes('vision') && ( + + 👁 + + )} + {model.abilities?.includes('func_call') && ( + + 🔧 + + )} +
+ +
+ ))} + {models.embedding.map((model) => ( +
handleEditEmbedding(model.uuid)} + > +
+ + {model.name} + + + {t('models.embedding')} + +
+ +
+ ))} + {models.llm.length === 0 && models.embedding.length === 0 && ( +

+ {t('models.noModels')} +

+ )} +
+ ) : ( +

+ {t('models.noModels')} +

+ )} +
+
+ + + ); + } + + // Virtual LangBot Models card if not exists + function renderLangBotModelsCard() { + if (langbotProvider) { + return renderProviderCard(langbotProvider, true); + } + return ( + + +
+
+
+ +
+
+ + {LANGBOT_MODELS_PROVIDER_NAME} + +

+ {t('models.langbotModelsDescription')} +

+
+
+ {accountType !== 'space' && ( + )}
- - -
- - {activeTab === 'llm' ? ( -

- {t('llm.description')} -

- ) : ( -

- {t('embedding.description')} -

- )} -
- - - {llmList.length === 0 ? ( -
- {isSpaceModel - ? t('models.noSpaceModels') - : t('models.noLocalModels')} -
- ) : ( -
- {llmList.map((cardVO) => ( -
selectLLM(cardVO, isSpaceModel)} - className={isSpaceModel ? 'cursor-default' : 'cursor-pointer'} - > - -
- ))} -
- )} -
- - - {embeddingList.length === 0 ? ( -
- {isSpaceModel - ? t('models.noSpaceModels') - : t('models.noLocalModels')} -
- ) : ( -
- {embeddingList.map((cardVO) => ( -
selectEmbedding(cardVO, isSpaceModel)} - className={isSpaceModel ? 'cursor-default' : 'cursor-pointer'} - > - -
- ))} -
- )} -
-
+ + ); } - function getDialogTitle() { - switch (viewMode) { - case 'space': - return 'Space ' + t('models.title'); - case 'local': - return t('models.localProvider') + ' ' + t('models.title'); - default: - return t('models.title'); - } + function handleFormClose() { + setLLMFormOpen(false); + setEmbeddingFormOpen(false); + setProviderFormOpen(false); + loadProviders(); + // Refresh expanded providers + expandedProviders.forEach((uuid) => loadProviderModels(uuid)); } return ( @@ -459,89 +516,101 @@ export default function ModelsDialog({ { - if (!newOpen && (modalOpen || embeddingModalOpen)) { + if ( + !newOpen && + (llmFormOpen || embeddingFormOpen || providerFormOpen) + ) return; - } - if (!newOpen) { - setViewMode('providers'); - } onOpenChange(newOpen); }} > - + -
- {viewMode !== 'providers' && ( - - )} - {getDialogTitle()} -
+ {t('models.title')}
-
- {viewMode === 'providers' && renderProviderCards()} - {viewMode === 'space' && - renderModelList(spaceLLMList, spaceEmbeddingList, true)} - {viewMode === 'local' && - renderModelList(localLLMList, localEmbeddingList, false)} +
+ {/* Fixed LangBot Models Card */} +
{renderLangBotModelsCard()}
+ + {/* Add Model Button */} +
+ + + + + + + + {t('models.addLLMModel')} + + + + {t('models.addEmbeddingModel')} + + + +
+ + {/* Scrollable Provider List */} +
+ {otherProviders.map((p) => renderProviderCard(p))} +
- - + + - {isEditForm ? t('models.editModel') : t('models.createModel')} + {editingLLMId ? t('models.editModel') : t('models.createModel')} { - setModalOpen(false); - loadAllModels(); - }} - onFormCancel={() => { - setModalOpen(false); - }} - onLLMDeleted={() => { - setModalOpen(false); - loadAllModels(); - }} + editMode={!!editingLLMId} + initLLMId={editingLLMId || undefined} + providers={providers} + onFormSubmit={handleFormClose} + onFormCancel={() => setLLMFormOpen(false)} + onLLMDeleted={handleFormClose} /> - - + + - {isEditEmbeddingForm + {editingEmbeddingId ? t('embedding.editModel') : t('embedding.createModel')} { - setEmbeddingModalOpen(false); - loadAllModels(); - }} - onFormCancel={() => { - setEmbeddingModalOpen(false); - }} - onEmbeddingDeleted={() => { - setEmbeddingModalOpen(false); - loadAllModels(); - }} + editMode={!!editingEmbeddingId} + initEmbeddingId={editingEmbeddingId || undefined} + providers={providers} + onFormSubmit={handleFormClose} + onFormCancel={() => setEmbeddingFormOpen(false)} + onEmbeddingDeleted={handleFormClose} + /> + + + + + + + {t('models.editProvider')} + + setProviderFormOpen(false)} /> diff --git a/web/src/app/home/components/models-dialog/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/components/models-dialog/component/embedding-form/EmbeddingForm.tsx index 00399359..5bd3b66a 100644 --- a/web/src/app/home/components/models-dialog/component/embedding-form/EmbeddingForm.tsx +++ b/web/src/app/home/components/models-dialog/component/embedding-form/EmbeddingForm.tsx @@ -1,9 +1,6 @@ -import { ICreateEmbeddingField } from '../ICreateEmbeddingField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '../ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; -import { EmbeddingModel } from '@/app/infra/entities/api'; -import { UUID } from 'uuidjs'; +import { ModelProvider } from '@/app/infra/entities/api'; import { zodResolver } from '@hookform/resolvers/zod'; import { useForm } from 'react-hook-form'; @@ -42,59 +39,43 @@ import { toast } from 'sonner'; import { extractI18nObject } from '@/i18n/I18nProvider'; import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'; import { AlertCircle } from 'lucide-react'; - -const getExtraArgSchema = (t: (key: string) => string) => - z - .object({ - key: z.string().min(1, { message: t('models.keyNameRequired') }), - type: z.enum(['string', 'number', 'boolean']), - value: z.string(), - }) - .superRefine((data, ctx) => { - if (data.type === 'number' && isNaN(Number(data.value))) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('models.mustBeValidNumber'), - path: ['value'], - }); - } - if ( - data.type === 'boolean' && - data.value !== 'true' && - data.value !== 'false' - ) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('models.mustBeTrueOrFalse'), - path: ['value'], - }); - } - }); +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; const getFormSchema = (t: (key: string) => string) => z.object({ name: z.string().min(1, { message: t('models.modelNameRequired') }), - model_provider: z - .string() - .min(1, { message: t('models.modelProviderRequired') }), - url: z.string().optional(), - api_key: z.string().optional(), - extra_args: z.array(getExtraArgSchema(t)).optional(), + provider_uuid: z.string().optional(), + new_provider_requester: z.string().optional(), + new_provider_url: z.string().optional(), + new_provider_api_key: z.string().optional(), + extra_args: z + .array( + z.object({ + key: z.string(), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }), + ) + .optional(), }); +interface EmbeddingFormProps { + editMode: boolean; + initEmbeddingId?: string; + providers: ModelProvider[]; + onFormSubmit: () => void; + onFormCancel: () => void; + onEmbeddingDeleted: () => void; +} + export default function EmbeddingForm({ editMode, initEmbeddingId, + providers, onFormSubmit, onFormCancel, onEmbeddingDeleted, -}: { - editMode: boolean; - initEmbeddingId?: string; - onFormSubmit: () => void; - onFormCancel: () => void; - onEmbeddingDeleted: () => void; -}) { +}: EmbeddingFormProps) { const { t } = useTranslation(); const formSchema = getFormSchema(t); @@ -102,9 +83,10 @@ export default function EmbeddingForm({ resolver: zodResolver(formSchema), defaultValues: { name: '', - model_provider: '', - url: '', - api_key: '', + provider_uuid: '', + new_provider_requester: '', + new_provider_url: '', + new_provider_api_key: '', extra_args: [], }, }); @@ -112,54 +94,178 @@ export default function EmbeddingForm({ const [extraArgs, setExtraArgs] = useState< { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] >([]); - const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); - const [requesterNameList, setRequesterNameList] = useState< - IChooseRequesterEntity[] - >([]); - const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< - string[] - >([]); const [modelTesting, setModelTesting] = useState(false); const [testErrorMessage, setTestErrorMessage] = useState(null); - const [currentModelProvider, setCurrentModelProvider] = useState(''); + const [providerMode, setProviderMode] = useState<'existing' | 'new'>( + 'existing', + ); + + const [requesterList, setRequesterList] = useState< + { label: string; value: string; category: string; defaultUrl: string }[] + >([]); useEffect(() => { - initEmbeddingModelFormComponent().then(() => { - if (editMode && initEmbeddingId) { - getEmbeddingConfig(initEmbeddingId).then((val) => { - form.setValue('name', val.name); - form.setValue('model_provider', val.model_provider); - setCurrentModelProvider(val.model_provider); - form.setValue('url', val.url); - form.setValue('api_key', val.api_key); - if (val.extra_args) { - const args = val.extra_args.map((arg) => { - const [key, value] = arg.split(':'); - let type: 'string' | 'number' | 'boolean' = 'string'; - if (!isNaN(Number(value))) { - type = 'number'; - } else if (value === 'true' || value === 'false') { - type = 'boolean'; - } - return { - key, - type, - value, - }; - }); - setExtraArgs(args); - form.setValue('extra_args', args); - } - }); - } else { - form.reset(); - } + loadRequesters(); + if (editMode && initEmbeddingId) { + loadModel(initEmbeddingId); + } + }, [editMode, initEmbeddingId]); + + async function loadRequesters() { + const resp = await httpClient.getProviderRequesters('text-embedding'); + setRequesterList( + resp.requesters.map((item) => ({ + label: extractI18nObject(item.label), + value: item.name, + category: item.spec.provider_category || 'manufacturer', + defaultUrl: + item.spec.config + .find((c) => c.name === 'base_url') + ?.default?.toString() || '', + })), + ); + } + + async function loadModel(id: string) { + const resp = await httpClient.getProviderEmbeddingModel(id); + const model = resp.model; + + form.setValue('name', model.name); + form.setValue('provider_uuid', model.provider_uuid); + + if (model.extra_args) { + const args = Object.entries(model.extra_args).map(([key, value]) => { + let type: 'string' | 'number' | 'boolean' = 'string'; + if (typeof value === 'number') type = 'number'; + else if (typeof value === 'boolean') type = 'boolean'; + return { key, type, value: String(value) }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + + setProviderMode('existing'); + } + + function handleFormSubmit(values: z.infer) { + const extraArgsObj: Record = {}; + values.extra_args?.forEach((arg) => { + if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value); + else if (arg.type === 'boolean') + extraArgsObj[arg.key] = arg.value === 'true'; + else extraArgsObj[arg.key] = arg.value; }); - }, []); + + const modelData: Record = { + name: values.name, + extra_args: extraArgsObj, + }; + + if (providerMode === 'existing' && values.provider_uuid) { + modelData.provider_uuid = values.provider_uuid; + } else if (providerMode === 'new') { + modelData.provider = { + requester: values.new_provider_requester, + base_url: values.new_provider_url, + api_keys: values.new_provider_api_key + ? [values.new_provider_api_key] + : [], + }; + } + + if (editMode && initEmbeddingId) { + updateModel(initEmbeddingId, modelData); + } else { + createModel(modelData); + } + } + + async function createModel(data: Record) { + try { + await httpClient.createProviderEmbeddingModel(data as never); + toast.success(t('models.createSuccess')); + onFormSubmit(); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function updateModel(id: string, data: Record) { + try { + await httpClient.updateProviderEmbeddingModel(id, data as never); + toast.success(t('models.saveSuccess')); + onFormSubmit(); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + async function deleteModel() { + if (!initEmbeddingId) return; + try { + await httpClient.deleteProviderEmbeddingModel(initEmbeddingId); + toast.success(t('models.deleteSuccess')); + onEmbeddingDeleted(); + } catch (err) { + toast.error(t('models.deleteError') + (err as Error).message); + } + } + + async function testModel() { + setModelTesting(true); + setTestErrorMessage(null); + + const values = form.getValues(); + const extraArgsObj: Record = {}; + values.extra_args?.forEach((arg) => { + if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value); + else if (arg.type === 'boolean') + extraArgsObj[arg.key] = arg.value === 'true'; + else extraArgsObj[arg.key] = arg.value; + }); + + let provider: Record; + if (providerMode === 'existing' && values.provider_uuid) { + const p = providers.find((p) => p.uuid === values.provider_uuid); + provider = { + requester: p?.requester || '', + base_url: p?.base_url || '', + api_keys: p?.api_keys || [], + }; + } else { + provider = { + requester: values.new_provider_requester, + base_url: values.new_provider_url, + api_keys: values.new_provider_api_key + ? [values.new_provider_api_key] + : [], + }; + } + + try { + await httpClient.testEmbeddingModel('_', { + uuid: '', + name: values.name, + provider_uuid: '', + provider, + extra_args: extraArgsObj, + } as never); + toast.success(t('models.testSuccess')); + } catch (err) { + setTestErrorMessage((err as Error).message || t('models.testError')); + } finally { + setModelTesting(false); + } + } const addExtraArg = () => { - setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + const newArgs = [ + ...extraArgs, + { key: '', type: 'string' as const, value: '' }, + ]; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); }; const updateExtraArg = ( @@ -168,10 +274,7 @@ export default function EmbeddingForm({ value: string, ) => { const newArgs = [...extraArgs]; - newArgs[index] = { - ...newArgs[index], - [field]: value, - }; + newArgs[index] = { ...newArgs[index], [field]: value }; setExtraArgs(newArgs); form.setValue('extra_args', newArgs); }; @@ -182,167 +285,6 @@ export default function EmbeddingForm({ form.setValue('extra_args', newArgs); }; - async function initEmbeddingModelFormComponent() { - const requesterNameList = - await httpClient.getProviderRequesters('text-embedding'); - setRequesterNameList( - requesterNameList.requesters.map((item) => { - return { - label: extractI18nObject(item.label), - value: item.name, - provider_category: item.spec.provider_category || 'manufacturer', - description: extractI18nObject(item.description) || undefined, - }; - }), - ); - setRequesterDefaultURLList( - requesterNameList.requesters.map((item) => { - const config = item.spec.config; - for (let i = 0; i < config.length; i++) { - if (config[i].name == 'base_url') { - return config[i].default?.toString() || ''; - } - } - return ''; - }), - ); - } - - async function getEmbeddingConfig( - id: string, - ): Promise { - const embeddingModel = await httpClient.getProviderEmbeddingModel(id); - - const fakeExtraArgs = []; - const extraArgs = embeddingModel.model.extra_args as Record; - for (const key in extraArgs) { - fakeExtraArgs.push(`${key}:${extraArgs[key]}`); - } - return { - name: embeddingModel.model.name, - model_provider: embeddingModel.model.requester, - url: embeddingModel.model.requester_config?.base_url, - api_key: embeddingModel.model.api_keys[0], - extra_args: fakeExtraArgs, - }; - } - - function handleFormSubmit(value: z.infer) { - const extraArgsObj: Record = {}; - value.extra_args?.forEach( - (arg: { key: string; type: string; value: string }) => { - if (arg.type === 'number') { - extraArgsObj[arg.key] = Number(arg.value); - } else if (arg.type === 'boolean') { - extraArgsObj[arg.key] = arg.value === 'true'; - } else { - extraArgsObj[arg.key] = arg.value; - } - }, - ); - - const embeddingModel: EmbeddingModel = { - uuid: editMode ? initEmbeddingId || '' : UUID.generate(), - name: value.name, - description: '', - requester: value.model_provider, - requester_config: { - base_url: value.url || '', - timeout: 120, - }, - extra_args: extraArgsObj, - api_keys: value.api_key ? [value.api_key] : [], - }; - - if (editMode) { - onSaveEdit(embeddingModel).then(() => { - form.reset(); - }); - } else { - onCreateEmbedding(embeddingModel).then(() => { - form.reset(); - }); - } - } - - async function onCreateEmbedding(embeddingModel: EmbeddingModel) { - try { - await httpClient.createProviderEmbeddingModel(embeddingModel); - onFormSubmit(); - toast.success(t('models.createSuccess')); - } catch (err) { - toast.error(t('models.createError') + (err as Error).message); - } - } - - async function onSaveEdit(embeddingModel: EmbeddingModel) { - try { - await httpClient.updateProviderEmbeddingModel( - initEmbeddingId || '', - embeddingModel, - ); - onFormSubmit(); - toast.success(t('models.saveSuccess')); - } catch (err) { - toast.error(t('models.saveError') + (err as Error).message); - } - } - - function deleteModel() { - if (initEmbeddingId) { - httpClient - .deleteProviderEmbeddingModel(initEmbeddingId) - .then(() => { - onEmbeddingDeleted(); - toast.success(t('models.deleteSuccess')); - }) - .catch((err) => { - toast.error(t('models.deleteError') + err.message); - }); - } - } - - function testEmbeddingModelInForm() { - setModelTesting(true); - setTestErrorMessage(null); - const extraArgsObj: Record = {}; - form - .getValues('extra_args') - ?.forEach((arg: { key: string; type: string; value: string }) => { - if (arg.type === 'number') { - extraArgsObj[arg.key] = Number(arg.value); - } else if (arg.type === 'boolean') { - extraArgsObj[arg.key] = arg.value === 'true'; - } else { - extraArgsObj[arg.key] = arg.value; - } - }); - const apiKey = form.getValues('api_key'); - httpClient - .testEmbeddingModel('_', { - uuid: '', - name: form.getValues('name'), - description: '', - requester: form.getValues('model_provider'), - requester_config: { - base_url: form.getValues('url') ?? '', - timeout: 120, - }, - api_keys: apiKey ? [apiKey] : [], - extra_args: extraArgsObj, - }) - .then(() => { - toast.success(t('models.testSuccess')); - setTestErrorMessage(null); - }) - .catch((err: { message?: string }) => { - setTestErrorMessage(err?.message || t('models.testError')); - }) - .finally(() => { - setModelTesting(false); - }); - } - return (
-
- ( - - - {t('models.modelName')} - * - - - - - - - {t('models.modelProviderDescription')} - - - )} - /> - - ( - - - {t('models.modelProvider')} - * - - - - - {currentModelProvider && - requesterNameList.find( - (item) => item.value === currentModelProvider, - )?.description && ( - - { - requesterNameList.find( - (item) => item.value === currentModelProvider, - )?.description - } - - )} - - - )} - /> - - {!['seekdb-embedding'].includes(currentModelProvider) && ( - ( - - {t('models.requestURL')} - - - - - - )} - /> + ( + + + {t('models.modelName')} + * + + + + + + {t('models.modelProviderDescription')} + + + )} + /> - {!['ollama-chat', 'seekdb-embedding'].includes( - currentModelProvider, - ) && ( - ( - - {t('models.apiKey')} - - - - - - )} - /> - )} +
+ {t('models.provider')} + setProviderMode(v as 'existing' | 'new')} + className="mt-2" + > + + + {t('models.existingProvider')} + + {t('models.newProvider')} + - - {t('models.extraParameters')} -
- {extraArgs.map((arg, index) => ( -
- - updateExtraArg(index, 'key', e.target.value) - } - /> - - - updateExtraArg(index, 'value', e.target.value) - } - /> - -
- ))} - -
- - {t('embedding.extraParametersDescription')} - - -
+ + + + + {providers.map((p) => ( + + {p.name} ({p.base_url || 'default'}) + + ))} + + + + + )} + /> + + + + ( + + {t('models.requester')} + + + + )} + /> + + ( + + {t('models.requestURL')} + + + + + + )} + /> + + ( + + {t('models.apiKey')} + + + + + + )} + /> + +
+ + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + +
+ {testErrorMessage && ( @@ -612,6 +548,7 @@ export default function EmbeddingForm({ )} + {editMode && ( - - - diff --git a/web/src/app/home/components/models-dialog/component/llm-form/LLMForm.tsx b/web/src/app/home/components/models-dialog/component/llm-form/LLMForm.tsx index fe0d21b2..fb8a79ae 100644 --- a/web/src/app/home/components/models-dialog/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/components/models-dialog/component/llm-form/LLMForm.tsx @@ -1,9 +1,6 @@ -import { ICreateLLMField } from '../ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '../ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; -import { LLMModel } from '@/app/infra/entities/api'; -import { UUID } from 'uuidjs'; +import { ModelProvider } from '@/app/infra/entities/api'; import { zodResolver } from '@hookform/resolvers/zod'; import { useForm } from 'react-hook-form'; @@ -43,60 +40,45 @@ import { toast } from 'sonner'; import { extractI18nObject } from '@/i18n/I18nProvider'; import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'; import { AlertCircle } from 'lucide-react'; - -const getExtraArgSchema = (t: (key: string) => string) => - z - .object({ - key: z.string().min(1, { message: t('models.keyNameRequired') }), - type: z.enum(['string', 'number', 'boolean']), - value: z.string(), - }) - .superRefine((data, ctx) => { - if (data.type === 'number' && isNaN(Number(data.value))) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('models.mustBeValidNumber'), - path: ['value'], - }); - } - if ( - data.type === 'boolean' && - data.value !== 'true' && - data.value !== 'false' - ) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: t('models.mustBeTrueOrFalse'), - path: ['value'], - }); - } - }); +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; const getFormSchema = (t: (key: string) => string) => z.object({ name: z.string().min(1, { message: t('models.modelNameRequired') }), - model_provider: z - .string() - .min(1, { message: t('models.modelProviderRequired') }), - url: z.string().min(1, { message: t('models.requestURLRequired') }), - api_key: z.string().optional(), + provider_uuid: z.string().optional(), + // New provider fields + new_provider_requester: z.string().optional(), + new_provider_url: z.string().optional(), + new_provider_api_key: z.string().optional(), abilities: z.array(z.string()), - extra_args: z.array(getExtraArgSchema(t)).optional(), + extra_args: z + .array( + z.object({ + key: z.string(), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }), + ) + .optional(), }); +interface LLMFormProps { + editMode: boolean; + initLLMId?: string; + providers: ModelProvider[]; + onFormSubmit: () => void; + onFormCancel: () => void; + onLLMDeleted: () => void; +} + export default function LLMForm({ editMode, initLLMId, + providers, onFormSubmit, onFormCancel, onLLMDeleted, -}: { - editMode: boolean; - initLLMId?: string; - onFormSubmit: () => void; - onFormCancel: () => void; - onLLMDeleted: () => void; -}) { +}: LLMFormProps) { const { t } = useTranslation(); const formSchema = getFormSchema(t); @@ -104,9 +86,10 @@ export default function LLMForm({ resolver: zodResolver(formSchema), defaultValues: { name: '', - model_provider: '', - url: '', - api_key: '', + provider_uuid: '', + new_provider_requester: '', + new_provider_url: '', + new_provider_api_key: '', abilities: [], extra_args: [], }, @@ -115,69 +98,186 @@ export default function LLMForm({ const [extraArgs, setExtraArgs] = useState< { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] >([]); - const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); - const abilityOptions: { label: string; value: string }[] = [ - { - label: t('models.visionAbility'), - value: 'vision', - }, - { - label: t('models.functionCallAbility'), - value: 'func_call', - }, - ]; - const [requesterNameList, setRequesterNameList] = useState< - IChooseRequesterEntity[] - >([]); - const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< - string[] - >([]); const [modelTesting, setModelTesting] = useState(false); const [testErrorMessage, setTestErrorMessage] = useState(null); - const [currentModelProvider, setCurrentModelProvider] = useState(''); + const [providerMode, setProviderMode] = useState<'existing' | 'new'>( + 'existing', + ); + + const [requesterList, setRequesterList] = useState< + { label: string; value: string; category: string; defaultUrl: string }[] + >([]); + + const abilityOptions = [ + { label: t('models.visionAbility'), value: 'vision' }, + { label: t('models.functionCallAbility'), value: 'func_call' }, + ]; useEffect(() => { - initLLMModelFormComponent().then(() => { - if (editMode && initLLMId) { - getLLMConfig(initLLMId).then((val) => { - form.setValue('name', val.name); - form.setValue('model_provider', val.model_provider); - setCurrentModelProvider(val.model_provider); - form.setValue('url', val.url); - form.setValue('api_key', val.api_key); - form.setValue( - 'abilities', - val.abilities as ('vision' | 'func_call')[], - ); - // 转换extra_args为新格式 - if (val.extra_args) { - const args = val.extra_args.map((arg) => { - const [key, value] = arg.split(':'); - let type: 'string' | 'number' | 'boolean' = 'string'; - if (!isNaN(Number(value))) { - type = 'number'; - } else if (value === 'true' || value === 'false') { - type = 'boolean'; - } - return { - key, - type, - value, - }; - }); - setExtraArgs(args); - form.setValue('extra_args', args); - } - }); - } else { - form.reset(); - } + loadRequesters(); + if (editMode && initLLMId) { + loadModel(initLLMId); + } + }, [editMode, initLLMId]); + + async function loadRequesters() { + const resp = await httpClient.getProviderRequesters('llm'); + setRequesterList( + resp.requesters.map((item) => ({ + label: extractI18nObject(item.label), + value: item.name, + category: item.spec.provider_category || 'manufacturer', + defaultUrl: + item.spec.config + .find((c) => c.name === 'base_url') + ?.default?.toString() || '', + })), + ); + } + + async function loadModel(id: string) { + const resp = await httpClient.getProviderLLMModel(id); + const model = resp.model; + + form.setValue('name', model.name); + form.setValue('provider_uuid', model.provider_uuid); + form.setValue('abilities', model.abilities || []); + + if (model.extra_args) { + const args = Object.entries(model.extra_args).map(([key, value]) => { + let type: 'string' | 'number' | 'boolean' = 'string'; + if (typeof value === 'number') type = 'number'; + else if (typeof value === 'boolean') type = 'boolean'; + return { key, type, value: String(value) }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + + setProviderMode('existing'); + } + + function handleFormSubmit(values: z.infer) { + const extraArgsObj: Record = {}; + values.extra_args?.forEach((arg) => { + if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value); + else if (arg.type === 'boolean') + extraArgsObj[arg.key] = arg.value === 'true'; + else extraArgsObj[arg.key] = arg.value; }); - }, []); + + const modelData: Record = { + name: values.name, + abilities: values.abilities, + extra_args: extraArgsObj, + }; + + if (providerMode === 'existing' && values.provider_uuid) { + modelData.provider_uuid = values.provider_uuid; + } else if (providerMode === 'new') { + modelData.provider = { + requester: values.new_provider_requester, + base_url: values.new_provider_url, + api_keys: values.new_provider_api_key + ? [values.new_provider_api_key] + : [], + }; + } + + if (editMode && initLLMId) { + updateModel(initLLMId, modelData); + } else { + createModel(modelData); + } + } + + async function createModel(data: Record) { + try { + await httpClient.createProviderLLMModel(data as never); + toast.success(t('models.createSuccess')); + onFormSubmit(); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function updateModel(id: string, data: Record) { + try { + await httpClient.updateProviderLLMModel(id, data as never); + toast.success(t('models.saveSuccess')); + onFormSubmit(); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + async function deleteModel() { + if (!initLLMId) return; + try { + await httpClient.deleteProviderLLMModel(initLLMId); + toast.success(t('models.deleteSuccess')); + onLLMDeleted(); + } catch (err) { + toast.error(t('models.deleteError') + (err as Error).message); + } + } + + async function testModel() { + setModelTesting(true); + setTestErrorMessage(null); + + const values = form.getValues(); + const extraArgsObj: Record = {}; + values.extra_args?.forEach((arg) => { + if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value); + else if (arg.type === 'boolean') + extraArgsObj[arg.key] = arg.value === 'true'; + else extraArgsObj[arg.key] = arg.value; + }); + + let provider: Record; + if (providerMode === 'existing' && values.provider_uuid) { + const p = providers.find((p) => p.uuid === values.provider_uuid); + provider = { + requester: p?.requester || '', + base_url: p?.base_url || '', + api_keys: p?.api_keys || [], + }; + } else { + provider = { + requester: values.new_provider_requester, + base_url: values.new_provider_url, + api_keys: values.new_provider_api_key + ? [values.new_provider_api_key] + : [], + }; + } + + try { + await httpClient.testLLMModel('_', { + uuid: '', + name: values.name, + provider_uuid: '', + provider, + abilities: values.abilities, + extra_args: extraArgsObj, + } as never); + toast.success(t('models.testSuccess')); + } catch (err) { + setTestErrorMessage((err as Error).message || t('models.testError')); + } finally { + setModelTesting(false); + } + } const addExtraArg = () => { - setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + const newArgs = [ + ...extraArgs, + { key: '', type: 'string' as const, value: '' }, + ]; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); }; const updateExtraArg = ( @@ -186,10 +286,7 @@ export default function LLMForm({ value: string, ) => { const newArgs = [...extraArgs]; - newArgs[index] = { - ...newArgs[index], - [field]: value, - }; + newArgs[index] = { ...newArgs[index], [field]: value }; setExtraArgs(newArgs); form.setValue('extra_args', newArgs); }; @@ -200,163 +297,6 @@ export default function LLMForm({ form.setValue('extra_args', newArgs); }; - async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters('llm'); - setRequesterNameList( - requesterNameList.requesters.map((item) => { - return { - label: extractI18nObject(item.label), - value: item.name, - provider_category: item.spec.provider_category || 'manufacturer', - }; - }), - ); - setRequesterDefaultURLList( - requesterNameList.requesters.map((item) => { - const config = item.spec.config; - for (let i = 0; i < config.length; i++) { - if (config[i].name == 'base_url') { - return config[i].default?.toString() || ''; - } - } - return ''; - }), - ); - } - - async function getLLMConfig(id: string): Promise { - const llmModel = await httpClient.getProviderLLMModel(id); - - const fakeExtraArgs = []; - const extraArgs = llmModel.model.extra_args as Record; - for (const key in extraArgs) { - fakeExtraArgs.push(`${key}:${extraArgs[key]}`); - } - return { - name: llmModel.model.name, - model_provider: llmModel.model.requester, - url: llmModel.model.requester_config?.base_url, - api_key: llmModel.model.api_keys[0], - abilities: llmModel.model.abilities || [], - extra_args: fakeExtraArgs, - }; - } - - function handleFormSubmit(value: z.infer) { - const extraArgsObj: Record = {}; - value.extra_args?.forEach( - (arg: { key: string; type: string; value: string }) => { - if (arg.type === 'number') { - extraArgsObj[arg.key] = Number(arg.value); - } else if (arg.type === 'boolean') { - extraArgsObj[arg.key] = arg.value === 'true'; - } else { - extraArgsObj[arg.key] = arg.value; - } - }, - ); - - const llmModel: LLMModel = { - uuid: editMode ? initLLMId || '' : UUID.generate(), - name: value.name, - description: '', - requester: value.model_provider, - requester_config: { - base_url: value.url, - timeout: 120, - }, - extra_args: extraArgsObj, - api_keys: value.api_key ? [value.api_key] : [], - abilities: value.abilities, - }; - - if (editMode) { - onSaveEdit(llmModel).then(() => { - form.reset(); - }); - } else { - onCreateLLM(llmModel).then(() => { - form.reset(); - }); - } - } - - async function onCreateLLM(llmModel: LLMModel) { - try { - await httpClient.createProviderLLMModel(llmModel); - onFormSubmit(); - toast.success(t('models.createSuccess')); - } catch (err) { - toast.error(t('models.createError') + (err as Error).message); - } - } - - async function onSaveEdit(llmModel: LLMModel) { - try { - await httpClient.updateProviderLLMModel(initLLMId || '', llmModel); - onFormSubmit(); - toast.success(t('models.saveSuccess')); - } catch (err) { - toast.error(t('models.saveError') + (err as Error).message); - } - } - - function deleteModel() { - if (initLLMId) { - httpClient - .deleteProviderLLMModel(initLLMId) - .then(() => { - onLLMDeleted(); - toast.success(t('models.deleteSuccess')); - }) - .catch((err) => { - toast.error(t('models.deleteError') + err.message); - }); - } - } - - function testLLMModelInForm() { - setModelTesting(true); - setTestErrorMessage(null); - const extraArgsObj: Record = {}; - form - .getValues('extra_args') - ?.forEach((arg: { key: string; type: string; value: string }) => { - if (arg.type === 'number') { - extraArgsObj[arg.key] = Number(arg.value); - } else if (arg.type === 'boolean') { - extraArgsObj[arg.key] = arg.value === 'true'; - } else { - extraArgsObj[arg.key] = arg.value; - } - }); - const apiKey = form.getValues('api_key'); - httpClient - .testLLMModel('_', { - uuid: '', - name: form.getValues('name'), - description: '', - requester: form.getValues('model_provider'), - requester_config: { - base_url: form.getValues('url'), - timeout: 120, - }, - api_keys: apiKey ? [apiKey] : [], - abilities: form.getValues('abilities'), - extra_args: extraArgsObj, - }) - .then(() => { - toast.success(t('models.testSuccess')); - setTestErrorMessage(null); - }) - .catch((err: { message?: string }) => { - setTestErrorMessage(err?.message || t('models.testError')); - }) - .finally(() => { - setModelTesting(false); - }); - } - return (
-
- ( - - - {t('models.modelName')} - * - - - - - - - {t('models.modelProviderDescription')} - - - )} - /> - - ( - - - {t('models.modelProvider')} - * - - - - - - - )} - /> - - ( - - - {t('models.requestURL')} - * - - - - - - - )} - /> - - {!['lmstudio-chat-completions', 'ollama-chat'].includes( - currentModelProvider, - ) && ( - ( - - {t('models.apiKey')} - - - - - - )} - /> + ( + + + {t('models.modelName')} + * + + + + + + {t('models.modelProviderDescription')} + + + )} + /> - ( - - {t('models.abilities')} -
- - {t('models.selectModelAbilities')} - -
- {abilityOptions.map((item) => ( - { - return ( - - - { - return checked - ? field.onChange([ - ...(field.value || []), - item.value, - ]) - : field.onChange( - field.value?.filter( - (value: string) => - value !== item.value, - ), - ); - }} - /> - - - {item.label} - - - ); - }} - /> - ))} - -
- )} - /> +
+ {t('models.provider')} + setProviderMode(v as 'existing' | 'new')} + className="mt-2" + > + + + {t('models.existingProvider')} + + {t('models.newProvider')} + - - {t('models.extraParameters')} -
- {extraArgs.map((arg, index) => ( -
- - updateExtraArg(index, 'key', e.target.value) - } - /> - - - updateExtraArg(index, 'value', e.target.value) - } - /> - -
- ))} - -
- - {t('llm.extraParametersDescription')} - - -
+ + + + + {providers.map((p) => ( + + {p.name} ({p.base_url || 'default'}) + + ))} + + + + + )} + /> + + + + ( + + {t('models.requester')} + + + + )} + /> + + ( + + {t('models.requestURL')} + + + + + + )} + /> + + ( + + {t('models.apiKey')} + + + + + + )} + /> + +
+ + ( + + {t('models.abilities')} + + {t('models.selectModelAbilities')} + + {abilityOptions.map((item) => ( + ( + + + { + if (checked) { + field.onChange([ + ...(field.value || []), + item.value, + ]); + } else { + field.onChange( + field.value?.filter( + (v: string) => v !== item.value, + ), + ); + } + }} + /> + + + {item.label} + + + )} + /> + ))} + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('llm.extraParametersDescription')} + +
+ {testErrorMessage && ( @@ -659,6 +606,7 @@ export default function LLMForm({ )} + {editMode && ( - - - diff --git a/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx new file mode 100644 index 00000000..70afb369 --- /dev/null +++ b/web/src/app/home/components/models-dialog/component/provider-form/ProviderForm.tsx @@ -0,0 +1,242 @@ +import { useEffect, useState } from 'react'; +import { httpClient } from '@/app/infra/http/HttpClient'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { DialogFooter } from '@/components/ui/dialog'; +import { toast } from 'sonner'; +import { extractI18nObject } from '@/i18n/I18nProvider'; + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('models.providerNameRequired') }), + requester: z.string().min(1, { message: t('models.requesterRequired') }), + base_url: z.string(), + api_key: z.string().optional(), + }); + +interface ProviderFormProps { + providerId?: string; + onFormSubmit: () => void; + onFormCancel: () => void; +} + +export default function ProviderForm({ + providerId, + onFormSubmit, + onFormCancel, +}: ProviderFormProps) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + requester: '', + base_url: '', + api_key: '', + }, + }); + + const [requesterList, setRequesterList] = useState< + { label: string; value: string; category: string; defaultUrl: string }[] + >([]); + + useEffect(() => { + loadRequesters(); + if (providerId) { + loadProvider(providerId); + } + }, [providerId]); + + async function loadRequesters() { + const resp = await httpClient.getProviderRequesters('llm'); + setRequesterList( + resp.requesters.map((item) => ({ + label: extractI18nObject(item.label), + value: item.name, + category: item.spec.provider_category || 'manufacturer', + defaultUrl: + item.spec.config + .find((c) => c.name === 'base_url') + ?.default?.toString() || '', + })), + ); + } + + async function loadProvider(id: string) { + const resp = await httpClient.getModelProvider(id); + const provider = resp.provider; + + form.setValue('name', provider.name); + form.setValue('requester', provider.requester); + form.setValue('base_url', provider.base_url); + form.setValue('api_key', provider.api_keys?.[0] || ''); + } + + async function handleFormSubmit(values: z.infer) { + const data = { + name: values.name, + requester: values.requester, + base_url: values.base_url, + api_keys: values.api_key ? [values.api_key] : [], + }; + + try { + if (providerId) { + await httpClient.updateModelProvider(providerId, data); + toast.success(t('models.providerSaved')); + } else { + await httpClient.createModelProvider(data); + toast.success(t('models.providerCreated')); + } + onFormSubmit(); + } catch (err) { + toast.error(t('models.providerSaveError') + (err as Error).message); + } + } + + return ( + + + ( + + + {t('models.providerName')} + * + + + + + + + )} + /> + + ( + + + {t('models.requester')} + * + + + + + )} + /> + + ( + + {t('models.requestURL')} + + + + + + )} + /> + + ( + + {t('models.apiKey')} + + + + + + )} + /> + + + + + + + + ); +} diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx index 6ed5173c..3f9d443f 100644 --- a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -19,16 +19,12 @@ import { SelectContent, SelectGroup, SelectItem, + SelectLabel, SelectTrigger, SelectValue, } from '@/components/ui/select'; import { KnowledgeBase, EmbeddingModel } from '@/app/infra/entities/api'; import { toast } from 'sonner'; -import { - HoverCard, - HoverCardContent, - HoverCardTrigger, -} from '@/components/ui/hover-card'; const getFormSchema = (t: (key: string) => string) => z.object({ @@ -205,90 +201,35 @@ export default function KBForm({ /> - - {embeddingModels.map((model) => ( - - - - {model.name} - - - -
-
- icon -

- {model.name} -

-
-

- {model.description} -

- {model.requester_config && ( -
- - - - - Base URL: - - {model.requester_config.base_url} -
- )} - {model.extra_args && - Object.keys(model.extra_args).length > - 0 && ( -
-
- {t('models.extraParameters')} -
-
- {Object.entries( - model.extra_args as Record< - string, - unknown - >, - ).map(([key, value]) => ( -
- - {key}: - - - {JSON.stringify(value)} - -
- ))} -
-
- )} -
-
-
- ))} -
+ {(() => { + const grouped = embeddingModels.reduce( + (acc, model) => { + const providerName = + model.provider?.name || + model.provider?.requester || + 'Unknown'; + if (!acc[providerName]) acc[providerName] = []; + acc[providerName].push(model); + return acc; + }, + {} as Record, + ); + return Object.entries(grouped).map( + ([providerName, models]) => ( + + {providerName} + {models.map((model) => ( + + {model.name} + + ))} + + ), + ); + })()}
diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index 9ae309ac..407bb6fc 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -41,20 +41,33 @@ export interface ApiRespProviderLLMModel { model: LLMModel; } -export interface LLMModel { - name: string; - description: string; +export interface ModelProvider { uuid: string; + name: string; requester: string; - requester_config: { - base_url: string; - timeout: number; - }; - extra_args?: object; + base_url: string; api_keys: string[]; + llm_count?: number; + embedding_count?: number; + created_at?: string; + updated_at?: string; +} + +export interface ApiRespModelProviders { + providers: ModelProvider[]; +} + +export interface ApiRespModelProvider { + provider: ModelProvider; +} + +export interface LLMModel { + uuid: string; + name: string; + provider_uuid: string; + provider?: ModelProvider; abilities?: string[]; - // created_at: string; - // updated_at: string; + extra_args?: object; } export interface KnowledgeBase { @@ -76,18 +89,11 @@ export interface ApiRespProviderEmbeddingModel { } export interface EmbeddingModel { - name: string; - description: string; uuid: string; - requester: string; - requester_config: { - base_url: string; - timeout: number; - }; + name: string; + provider_uuid: string; + provider?: ModelProvider; extra_args?: object; - api_keys: string[]; - // created_at: string; - // updated_at: string; } export interface ApiRespPipelines { diff --git a/web/src/app/infra/http/BackendClient.ts b/web/src/app/infra/http/BackendClient.ts index 541ec02b..69da4e9d 100644 --- a/web/src/app/infra/http/BackendClient.ts +++ b/web/src/app/infra/http/BackendClient.ts @@ -38,6 +38,9 @@ import { ExternalKnowledgeBase, ApiRespExternalKnowledgeBases, ApiRespExternalKnowledgeBase, + ApiRespModelProviders, + ApiRespModelProvider, + ModelProvider, } from '@/app/infra/entities/api'; import { Plugin } from '@/app/infra/entities/plugin'; import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; @@ -65,7 +68,6 @@ export class BackendClient extends BaseHttpClient { public getProviderRequesterIconURL(name: string): string { if (this.instance.defaults.baseURL === '/') { - // 获取用户访问的URL const url = window.location.href; const baseURL = url.split('/').slice(0, 3).join('/'); return `${baseURL}/api/v1/provider/requesters/${name}/icon`; @@ -76,9 +78,38 @@ export class BackendClient extends BaseHttpClient { ); } + // ============ Model Providers ============ + public getModelProviders(): Promise { + return this.get('/api/v1/provider/providers'); + } + + public getModelProvider(uuid: string): Promise { + return this.get(`/api/v1/provider/providers/${uuid}`); + } + + public createModelProvider( + provider: Omit, + ): Promise<{ uuid: string }> { + return this.post('/api/v1/provider/providers', provider); + } + + public updateModelProvider( + uuid: string, + provider: Partial, + ): Promise { + return this.put(`/api/v1/provider/providers/${uuid}`, provider); + } + + public deleteModelProvider(uuid: string): Promise { + return this.delete(`/api/v1/provider/providers/${uuid}`); + } + // ============ Provider Model LLM ============ - public getProviderLLMModels(): Promise { - return this.get('/api/v1/provider/models/llm'); + public getProviderLLMModels( + providerUuid?: string, + ): Promise { + const params = providerUuid ? { provider_uuid: providerUuid } : {}; + return this.get('/api/v1/provider/models/llm', params); } public getProviderLLMModel(uuid: string): Promise { @@ -105,8 +136,11 @@ export class BackendClient extends BaseHttpClient { } // ============ Provider Model Embedding ============ - public getProviderEmbeddingModels(): Promise { - return this.get('/api/v1/provider/models/embedding'); + public getProviderEmbeddingModels( + providerUuid?: string, + ): Promise { + const params = providerUuid ? { provider_uuid: providerUuid } : {}; + return this.get('/api/v1/provider/models/embedding', params); } public getProviderEmbeddingModel( @@ -716,61 +750,4 @@ export class BackendClient extends BaseHttpClient { }> { return this.post('/api/v1/user/space/callback', { code }); } - - // ============ Space Models Sync API ============ - public syncSpaceModels(spaceUrl?: string): Promise<{ - created_llm: number; - updated_llm: number; - created_embedding: number; - updated_embedding: number; - skipped: number; - }> { - return this.post('/api/v1/space/models/sync', { space_url: spaceUrl }); - } - - public getSpaceModels(): Promise<{ - llm_models: Array<{ - uuid: string; - name: string; - description: string; - requester: string; - space_model_id: string; - source: string; - }>; - embedding_models: Array<{ - uuid: string; - name: string; - description: string; - requester: string; - space_model_id: string; - source: string; - }>; - }> { - return this.get('/api/v1/space/models'); - } - - public deleteSpaceModels(): Promise<{ - deleted_llm: number; - deleted_embedding: number; - }> { - return this.delete('/api/v1/space/models'); - } - - public getAvailableSpaceModels(spaceUrl?: string): Promise<{ - models: Array<{ - model_id: string; - display_name: { [key: string]: string }; - description: { [key: string]: string }; - category: string; - provider: string; - }>; - vendors: Array<{ - id: number; - name: string; - }>; - total: number; - }> { - const params = spaceUrl ? { space_url: spaceUrl } : {}; - return this.get('/api/v1/space/models/available', params); - } } diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1463d501..64cafc47 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -186,6 +186,36 @@ const enUS = { spaceModelReadOnly: 'Space models are read-only', noSpaceModels: 'No Space models. Click Sync to fetch models from Space.', noLocalModels: 'No local models. Click Create to add a model.', + // New keys for provider-based structure + addModel: 'Add Model', + addLLMModel: 'Add LLM Model', + addEmbeddingModel: 'Add Embedding Model', + provider: 'Provider', + existingProvider: 'Existing Provider', + newProvider: 'New Provider', + selectProvider: 'Select Provider', + requester: 'Requester', + selectRequester: 'Select Requester', + langbotModelsDescription: 'Cloud models powered by LangBot Space', + balance: 'Balance', + loginWithSpace: 'Login with Space', + loginToUseModels: 'Login with Space to use cloud models', + noModels: 'No models configured', + editProvider: 'Edit Provider', + providerName: 'Provider Name', + providerNameRequired: 'Provider name is required', + requesterRequired: 'Requester is required', + providerSaved: 'Provider saved', + providerCreated: 'Provider created', + providerSaveError: 'Failed to save provider: ', + providerDeleted: 'Provider deleted', + providerDeleteError: 'Failed to delete provider: ', + loadError: 'Failed to load data', + chat: 'Chat', + embedding: 'Embedding', + modelsCount: '{{count}} model(s)', + expandModels: 'Expand', + collapseModels: 'Collapse', }, bots: { title: 'Bots', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 2e435495..38dbea4a 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -192,6 +192,35 @@ const jaJP = { 'Space モデルがありません。同期ボタンをクリックして Space からモデルを取得してください。', noLocalModels: 'ローカルモデルがありません。作成ボタンをクリックしてモデルを追加してください。', + addModel: 'モデルを追加', + addLLMModel: 'LLMモデルを追加', + addEmbeddingModel: '埋め込みモデルを追加', + provider: 'プロバイダー', + existingProvider: '既存のプロバイダー', + newProvider: '新規プロバイダー', + selectProvider: 'プロバイダーを選択', + requester: 'リクエスター', + selectRequester: 'リクエスターを選択', + langbotModelsDescription: 'LangBot Space が提供するクラウドモデル', + balance: '残高', + loginWithSpace: 'Space でログイン', + loginToUseModels: 'Space でログインしてクラウドモデルを使用', + noModels: 'モデルがありません', + editProvider: 'プロバイダーを編集', + providerName: 'プロバイダー名', + providerNameRequired: 'プロバイダー名は必須です', + requesterRequired: 'リクエスターは必須です', + providerSaved: 'プロバイダーを保存しました', + providerCreated: 'プロバイダーを作成しました', + providerSaveError: 'プロバイダーの保存に失敗しました:', + providerDeleted: 'プロバイダーを削除しました', + providerDeleteError: 'プロバイダーの削除に失敗しました:', + loadError: 'データの読み込みに失敗しました', + chat: 'チャット', + embedding: '埋め込み', + modelsCount: '{{count}} 個のモデル', + expandModels: '展開', + collapseModels: '折りたたむ', }, bots: { title: 'ボット', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 1767bcf5..8c98c0f6 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -180,6 +180,36 @@ const zhHans = { spaceModelReadOnly: 'Space 模型为只读', noSpaceModels: '暂无 Space 模型。点击同步按钮从 Space 获取模型。', noLocalModels: '暂无本地模型。点击创建按钮添加模型。', + // 供应商结构新增键 + addModel: '添加模型', + addLLMModel: '添加对话模型', + addEmbeddingModel: '添加嵌入模型', + provider: '供应商', + existingProvider: '已有供应商', + newProvider: '新建供应商', + selectProvider: '选择供应商', + requester: '请求器', + selectRequester: '选择请求器', + langbotModelsDescription: 'LangBot Space 提供的云端模型', + balance: '余额', + loginWithSpace: '通过 Space 登录', + loginToUseModels: '通过 Space 登录以使用云端模型', + noModels: '暂无模型', + editProvider: '编辑供应商', + providerName: '供应商名称', + providerNameRequired: '供应商名称不能为空', + requesterRequired: '请求器不能为空', + providerSaved: '供应商已保存', + providerCreated: '供应商已创建', + providerSaveError: '保存供应商失败:', + providerDeleted: '供应商已删除', + providerDeleteError: '删除供应商失败:', + loadError: '加载数据失败', + chat: '对话', + embedding: '嵌入', + modelsCount: '{{count}} 个模型', + expandModels: '展开', + collapseModels: '收起', }, bots: { title: '机器人', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 55af4d62..ad4e5ced 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -180,6 +180,35 @@ const zhHant = { spaceModelReadOnly: 'Space 模型為唯讀', noSpaceModels: '暫無 Space 模型。點擊同步按鈕從 Space 取得模型。', noLocalModels: '暫無本地模型。點擊建立按鈕新增模型。', + addModel: '新增模型', + addLLMModel: '新增對話模型', + addEmbeddingModel: '新增嵌入模型', + provider: '供應商', + existingProvider: '現有供應商', + newProvider: '新供應商', + selectProvider: '選擇供應商', + requester: '請求器', + selectRequester: '選擇請求器', + langbotModelsDescription: '由 LangBot Space 提供的雲端模型', + balance: '餘額', + loginWithSpace: '使用 Space 登入', + loginToUseModels: '使用 Space 登入以使用雲端模型', + noModels: '暫無模型', + editProvider: '編輯供應商', + providerName: '供應商名稱', + providerNameRequired: '供應商名稱不能為空', + requesterRequired: '請求器不能為空', + providerSaved: '供應商已儲存', + providerCreated: '供應商已建立', + providerSaveError: '儲存供應商失敗:', + providerDeleted: '供應商已刪除', + providerDeleteError: '刪除供應商失敗:', + loadError: '載入資料失敗', + chat: '對話', + embedding: '嵌入', + modelsCount: '{{count}} 個模型', + expandModels: '展開', + collapseModels: '收起', }, bots: { title: '機器人',