feat: refactor model management to introduce provider structure, enhancing model organization and retrieval

This commit is contained in:
Junyan Qin
2025-12-26 20:27:33 +08:00
parent 455e3db28d
commit 57fcec011d
24 changed files with 2676 additions and 2106 deletions

View File

@@ -9,12 +9,15 @@ class LLMModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
provider_uuid = quart.request.args.get('provider_uuid')
if provider_uuid:
return self.success(
data={'models': await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)}
)
return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.llm_model_service.create_llm_model(json_data)
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
@@ -52,12 +55,19 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
provider_uuid = quart.request.args.get('provider_uuid')
if provider_uuid:
return self.success(
data={
'models': await self.ap.embedding_models_service.get_embedding_models_by_provider(
provider_uuid
)
}
)
return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data)
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)

View File

@@ -0,0 +1,45 @@
import quart
from ... import group
@group.group_class('models/providers', '/api/v1/provider/providers')
class ModelProvidersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
providers = await self.ap.provider_service.get_providers()
# Add model counts
for provider in providers:
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
provider['llm_count'] = counts['llm_count']
provider['embedding_count'] = counts['embedding_count']
return self.success(data={'providers': providers})
elif quart.request.method == 'POST':
json_data = await quart.request.json
provider_uuid = await self.ap.provider_service.create_provider(json_data)
return self.success(data={'uuid': provider_uuid})
@self.route(
'/<provider_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
)
async def _(provider_uuid: str) -> str:
if quart.request.method == 'GET':
provider = await self.ap.provider_service.get_provider(provider_uuid)
if provider is None:
return self.http_status(404, -1, 'provider not found')
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
provider['llm_count'] = counts['llm_count']
provider['embedding_count'] = counts['embedding_count']
return self.success(data={'provider': provider})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
await self.ap.provider_service.update_provider(provider_uuid, json_data)
return self.success()
elif quart.request.method == 'DELETE':
try:
await self.ap.provider_service.delete_provider(provider_uuid)
return self.success()
except ValueError as e:
return self.http_status(400, -1, str(e))

View File

@@ -1,52 +0,0 @@
import quart
from .. import group
DEFAULT_SPACE_URL = 'https://space.langbot.app'
@group.group_class('space', '/api/v1/space')
class SpaceRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/models/sync', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Sync models from Space MaaS to local database"""
json_data = await quart.request.json or {}
space_url = json_data.get('space_url', DEFAULT_SPACE_URL)
try:
stats = await self.ap.space_models_service.sync_models_from_space(user_email, space_url)
return self.success(data=stats)
except ValueError as e:
return self.fail(1, str(e))
except Exception as e:
return self.fail(2, f'Failed to sync models: {str(e)}')
@self.route('/models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Get all synced Space models"""
if quart.request.method == 'GET':
try:
models = await self.ap.space_models_service.get_space_models()
return self.success(data=models)
except Exception as e:
return self.fail(1, f'Failed to get Space models: {str(e)}')
elif quart.request.method == 'DELETE':
try:
stats = await self.ap.space_models_service.delete_space_models()
return self.success(data=stats)
except Exception as e:
return self.fail(1, f'Failed to delete Space models: {str(e)}')
@self.route('/models/available', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Get available models from Space (preview before sync)"""
try:
space_url = quart.request.args.get('space_url', DEFAULT_SPACE_URL)
models_data = await self.ap.space_models_service.fetch_space_models(space_url)
return self.success(data=models_data)
except ValueError as e:
return self.fail(1, str(e))
except Exception as e:
return self.fail(2, f'Failed to fetch available models: {str(e)}')

View File

@@ -11,6 +11,18 @@ from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
def _parse_provider_api_keys(provider_dict: dict) -> dict:
"""Parse api_keys if it's a JSON string"""
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
return provider_dict
class LLMModelsService:
ap: app.Application
@@ -18,29 +30,64 @@ class LLMModelsService:
self.ap = ap
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
"""Get all LLM models with provider info"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all()
masked_columns = []
if not include_secret:
masked_columns = ['api_keys']
# Get all providers for lookup
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
for model in models
]
models_list = []
for model in models:
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
provider = providers.get(model.provider_uuid)
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
provider_dict = _parse_provider_api_keys(provider_dict)
if not include_secret:
provider_dict['api_keys'] = ['***'] * len(provider_dict.get('api_keys', []))
model_dict['provider'] = provider_dict
models_list.append(model_dict)
return models_list
async def get_llm_models_by_provider(self, provider_uuid: str) -> list[dict]:
"""Get LLM models by provider UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.provider_uuid == provider_uuid
)
)
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models]
async def create_llm_model(self, model_data: dict) -> str:
"""Create a new LLM model"""
model_data['uuid'] = str(uuid.uuid4())
# Handle provider creation if needed
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
# Create new provider
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
llm_model = await self.get_llm_model(model_data['uuid'])
await self.ap.model_mgr.load_llm_model(llm_model)
# check if default pipeline has no model bound
# Check if default pipeline has no model bound
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
@@ -56,21 +103,47 @@ class LLMModelsService:
return model_data['uuid']
async def get_llm_model(self, model_uuid: str) -> dict | None:
"""Get a single LLM model with provider info"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
model = result.first()
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
# Get provider
provider_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == model.provider_uuid
)
)
provider = provider_result.first()
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
return model_dict
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
"""Update an existing LLM model"""
if 'uuid' in model_data:
del model_data['uuid']
# Handle provider update if needed
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.uuid == model_uuid)
@@ -78,19 +151,18 @@ class LLMModelsService:
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
llm_model = await self.get_llm_model(model_uuid)
await self.ap.model_mgr.load_llm_model(llm_model)
async def delete_llm_model(self, model_uuid: str) -> None:
"""Delete an LLM model"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
"""Test an LLM model"""
runtime_llm_model: model_requester.RuntimeLLMModel | None = None
if model_uuid != '_':
@@ -98,18 +170,11 @@ class LLMModelsService:
if model.model_entity.uuid == model_uuid:
runtime_llm_model = model
break
if runtime_llm_model is None:
raise Exception('model not found')
else:
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
# Mon Nov 10 2025: Commented for some providers may not support thinking parameter
# # 有些模型厂商默认开启了思考功能,测试容易延迟
# extra_args = model_data.get('extra_args', {})
# if not extra_args or 'thinking' not in extra_args:
# extra_args['thinking'] = {'type': 'disabled'}
extra_args = model_data.get('extra_args', {})
await runtime_llm_model.requester.invoke_llm(
query=None,
@@ -127,42 +192,103 @@ class EmbeddingModelsService:
self.ap = ap
async def get_embedding_models(self) -> list[dict]:
"""Get all embedding models with provider info"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models]
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
models_list = []
for model in models:
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
provider = providers.get(model.provider_uuid)
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
models_list.append(model_dict)
return models_list
async def get_embedding_models_by_provider(self, provider_uuid: str) -> list[dict]:
"""Get embedding models by provider UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
)
)
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models]
async def create_embedding_model(self, model_data: dict) -> str:
"""Create a new embedding model"""
model_data['uuid'] = str(uuid.uuid4())
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
)
embedding_model = await self.get_embedding_model(model_data['uuid'])
await self.ap.model_mgr.load_embedding_model(embedding_model)
return model_data['uuid']
async def get_embedding_model(self, model_uuid: str) -> dict | None:
"""Get a single embedding model with provider info"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.uuid == model_uuid
)
)
model = result.first()
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
provider_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == model.provider_uuid
)
)
provider = provider_result.first()
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
return model_dict
async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None:
"""Update an existing embedding model"""
if 'uuid' in model_data:
del model_data['uuid']
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.uuid == model_uuid)
@@ -170,21 +296,20 @@ class EmbeddingModelsService:
)
await self.ap.model_mgr.remove_embedding_model(model_uuid)
embedding_model = await self.get_embedding_model(model_uuid)
await self.ap.model_mgr.load_embedding_model(embedding_model)
async def delete_embedding_model(self, model_uuid: str) -> None:
"""Delete an embedding model"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.uuid == model_uuid
)
)
await self.ap.model_mgr.remove_embedding_model(model_uuid)
async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None:
"""Test an embedding model"""
runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None
if model_uuid != '_':
@@ -192,10 +317,8 @@ class EmbeddingModelsService:
if model.model_entity.uuid == model_uuid:
runtime_embedding_model = model
break
if runtime_embedding_model is None:
raise Exception('model not found')
else:
runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data)

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
import uuid
import sqlalchemy
from ....core import app
from ....entity.persistence import model as persistence_model
class ModelProviderService:
"""Service for managing model providers"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_providers(self) -> list[dict]:
"""Get all providers"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
providers = result.all()
providers_list = []
for p in providers:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, p)
# Parse api_keys if it's a JSON string
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
providers_list.append(provider_dict)
return providers_list
async def get_provider(self, provider_uuid: str) -> dict | None:
"""Get a single provider by UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == provider_uuid
)
)
provider = result.first()
if provider is None:
return None
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
# Parse api_keys if it's a JSON string
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
return provider_dict
async def create_provider(self, provider_data: dict) -> str:
"""Create a new provider"""
provider_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
return provider_data['uuid']
async def update_provider(self, provider_uuid: str, provider_data: dict) -> None:
"""Update an existing provider"""
if 'uuid' in provider_data:
del provider_data['uuid']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == provider_uuid)
.values(**provider_data)
)
# Reload all models using this provider
await self.ap.model_mgr.load_models_from_db()
async def delete_provider(self, provider_uuid: str) -> None:
"""Delete a provider (only if no models reference it)"""
# Check if any models use this provider
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.provider_uuid == provider_uuid
)
)
if llm_result.first() is not None:
raise ValueError('Cannot delete provider: LLM models still reference it')
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
)
)
if embedding_result.first() is not None:
raise ValueError('Cannot delete provider: Embedding models still reference it')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == provider_uuid
)
)
async def get_provider_model_counts(self, provider_uuid: str) -> dict:
"""Get count of models using this provider"""
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(persistence_model.LLMModel)
.where(persistence_model.LLMModel.provider_uuid == provider_uuid)
)
llm_count = llm_result.scalar() or 0
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.provider_uuid == provider_uuid)
)
embedding_count = embedding_result.scalar() or 0
return {'llm_count': llm_count, 'embedding_count': embedding_count}
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
"""Find existing provider or create new one"""
# Try to find existing provider with same config
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.requester == requester,
persistence_model.ModelProvider.base_url == base_url,
)
)
for provider in result.all():
if sorted(provider.api_keys or []) == sorted(api_keys or []):
return provider.uuid
# Create new provider
provider_name = requester
if base_url:
try:
from urllib.parse import urlparse
parsed = urlparse(base_url)
provider_name = parsed.netloc or requester
except Exception:
pass
return await self.create_provider(
{
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys or [],
}
)

View File

@@ -1,247 +0,0 @@
from __future__ import annotations
import typing
import uuid as uuid_lib
import aiohttp
import sqlalchemy
from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import user as persistence_user
DEFAULT_SPACE_URL = 'http://localhost:8383'
# Space's base URL for model API requests (used for requester_config)
SPACE_API_BASE_URL = 'http://localhost:8383'
class SpaceModelsService:
"""Service for syncing models from Space MaaS"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_space_user_info(self, user_email: str) -> persistence_user.User | None:
"""Get Space user info for sync operations"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_user.User).where(persistence_user.User.user == user_email)
)
result_list = result.all()
return result_list[0] if result_list else None
async def fetch_space_models(self, space_url: str = DEFAULT_SPACE_URL) -> typing.Dict:
"""Fetch available models from Space API"""
async with aiohttp.ClientSession() as session:
async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response:
if response.status != 200:
raise ValueError(f'Failed to fetch models from Space: {await response.text()}')
data = await response.json()
if data.get('code') != 0:
raise ValueError(f'Failed to fetch models from Space: {data.get("msg")}')
return data.get('data', {})
async def sync_models_from_space(
self, user_email: str, space_url: str = DEFAULT_SPACE_URL
) -> typing.Dict[str, typing.Any]:
"""
Sync models from Space to local database.
Returns statistics about the sync operation.
"""
# Get user info for API key
user_obj = await self.get_space_user_info(user_email)
if user_obj is None:
raise ValueError('User not found')
if user_obj.account_type != 'space':
raise ValueError('User is not a Space account')
if not user_obj.space_api_key:
raise ValueError('User does not have a Space API key configured')
# Fetch models from Space
models_data = await self.fetch_space_models(space_url)
space_models = models_data.get('models', [])
# Get existing Space models in local database
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
existing_space_models = {m.space_model_id: m for m in result.all()}
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
existing_space_embedding_models = {m.space_model_id: m for m in result.all()}
stats = {'created_llm': 0, 'updated_llm': 0, 'created_embedding': 0, 'updated_embedding': 0, 'skipped': 0}
for model in space_models:
model_id = model.get('model_id')
category = model.get('category', '')
if not model_id:
stats['skipped'] += 1
continue
if category == 'embedding':
# Handle embedding model
await self._sync_embedding_model(model, user_obj.space_api_key, existing_space_embedding_models, stats)
else:
# Handle LLM model (chat, completion, etc.)
await self._sync_llm_model(model, user_obj.space_api_key, existing_space_models, stats)
return stats
async def _sync_llm_model(
self,
model: typing.Dict,
api_key: str,
existing_models: typing.Dict[str, persistence_model.LLMModel],
stats: typing.Dict,
) -> None:
"""Sync a single LLM model from Space"""
model_id = model.get('model_id')
display_name = model.get('display_name', {})
name = display_name.get('zh_Hans', display_name.get('en_US', model_id))
description_obj = model.get('description', {})
description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else ''
# Infer abilities from model capabilities
abilities = []
supported_endpoints = model.get('supported_endpoints', [])
if 'vision' in str(supported_endpoints).lower() or 'vision' in model_id.lower():
abilities.append('vision')
if 'function' in str(supported_endpoints).lower() or 'tool' in str(supported_endpoints).lower():
abilities.append('function_call')
model_data = {
'name': name,
'description': description[:255] if description else 'Model from Space MaaS',
'requester': 'openai-chat-completions', # Space uses OpenAI-compatible API
'requester_config': {
'base-url': SPACE_API_BASE_URL,
'args': {},
'timeout': 120,
},
'api_keys': [api_key],
'abilities': abilities,
'extra_args': {'model': model_id},
'source': 'space',
'space_model_id': model_id,
}
if model_id in existing_models:
# Update existing model
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.space_model_id == model_id)
.values(**model_data)
)
stats['updated_llm'] += 1
else:
# Create new model
model_data['uuid'] = str(uuid_lib.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
stats['created_llm'] += 1
async def _sync_embedding_model(
self,
model: typing.Dict,
api_key: str,
existing_models: typing.Dict[str, persistence_model.EmbeddingModel],
stats: typing.Dict,
) -> None:
"""Sync a single embedding model from Space"""
model_id = model.get('model_id')
display_name = model.get('display_name', {})
name = display_name.get('zh_Hans', display_name.get('en_US', model_id))
description_obj = model.get('description', {})
description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else ''
model_data = {
'name': name,
'description': description[:255] if description else 'Embedding model from Space MaaS',
'requester': 'openai-embedding', # Space uses OpenAI-compatible API
'requester_config': {
'base-url': SPACE_API_BASE_URL,
'args': {},
'timeout': 120,
},
'api_keys': [api_key],
'extra_args': {'model': model_id},
'source': 'space',
'space_model_id': model_id,
}
if model_id in existing_models:
# Update existing model
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.space_model_id == model_id)
.values(**model_data)
)
stats['updated_embedding'] += 1
else:
# Create new model
model_data['uuid'] = str(uuid_lib.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
)
stats['created_embedding'] += 1
async def get_space_models(self) -> typing.Dict[str, typing.List]:
"""Get all synced Space models"""
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
return {
'llm_models': [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in llm_result.all()
],
'embedding_models': [
self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m)
for m in embedding_result.all()
],
}
async def delete_space_models(self) -> typing.Dict[str, int]:
"""Delete all synced Space models"""
# Remove from model manager first
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
for model in llm_result.all():
await self.ap.model_mgr.remove_llm_model(model.uuid)
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
for model in embedding_result.all():
await self.ap.model_mgr.remove_embedding_model(model.uuid)
# Delete from database
llm_delete = await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
embedding_delete = await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
return {'deleted_llm': llm_delete.rowcount, 'deleted_embedding': embedding_delete.rowcount}

View File

@@ -20,6 +20,7 @@ from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..api.http.service import user as user_service
from ..api.http.service import model as model_service
from ..api.http.service import provider as provider_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
@@ -27,7 +28,6 @@ from ..api.http.service import mcp as mcp_service
from ..api.http.service import apikey as apikey_service
from ..api.http.service import webhook as webhook_service
from ..api.http.service import external_kb as external_kb_service
from ..api.http.service import space_models as space_models_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
@@ -119,6 +119,8 @@ class Application:
embedding_models_service: model_service.EmbeddingModelsService = None
provider_service: provider_service.ModelProviderService = None
pipeline_service: pipeline_service.PipelineService = None
bot_service: bot_service.BotService = None
@@ -133,8 +135,6 @@ class Application:
webhook_service: webhook_service.WebhookService = None
space_models_service: space_models_service.SpaceModelsService = None
def __init__(self):
pass

View File

@@ -17,6 +17,7 @@ from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
from ...api.http.service import user as user_service
from ...api.http.service import model as model_service
from ...api.http.service import provider as provider_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
@@ -24,7 +25,6 @@ from ...api.http.service import mcp as mcp_service
from ...api.http.service import apikey as apikey_service
from ...api.http.service import webhook as webhook_service
from ...api.http.service import external_kb as external_kb_service
from ...api.http.service import space_models as space_models_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -115,6 +115,9 @@ class BuildAppStage(stage.BootingStage):
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
ap.embedding_models_service = embedding_models_service_inst
provider_service_inst = provider_service.ModelProviderService(ap)
ap.provider_service = provider_service_inst
pipeline_service_inst = pipeline_service.PipelineService(ap)
ap.pipeline_service = pipeline_service_inst
@@ -136,9 +139,6 @@ class BuildAppStage(stage.BootingStage):
webhook_service_inst = webhook_service.WebhookService(ap)
ap.webhook_service = webhook_service_inst
space_models_service_inst = space_models_service.SpaceModelsService(ap)
ap.space_models_service = space_models_service_inst
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
await asyncio.sleep(3)
await plugin_connector_inst.initialize()

View File

@@ -3,6 +3,25 @@ import sqlalchemy
from .base import Base
class ModelProvider(Base):
"""Model provider"""
__tablename__ = 'model_providers'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
base_url = sqlalchemy.Column(sqlalchemy.String(512), nullable=False)
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)
class LLMModel(Base):
"""LLM model"""
@@ -10,16 +29,9 @@ class LLMModel(Base):
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
# Source tracking for Space integration: 'local' or 'space'
source = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='local')
# Space model ID for synced models (used to track and update synced models)
space_model_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
@@ -30,21 +42,14 @@ class LLMModel(Base):
class EmbeddingModel(Base):
"""Embedding 模型"""
"""Embedding model"""
__tablename__ = 'embedding_models'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
# Source tracking for Space integration: 'local' or 'space'
source = sqlalchemy.Column(sqlalchemy.String(32), nullable=False, server_default='local')
# Space model ID for synced models (used to track and update synced models)
space_model_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,

View File

@@ -0,0 +1,286 @@
import uuid as uuid_lib
import sqlalchemy
from .. import migration
@migration.migration_class(16)
class DBMigrateModelProviderRefactor(migration.DBMigration):
"""Refactor model structure: create providers from existing models and update references"""
async def upgrade(self):
"""Upgrade"""
# Step 1: Create model_providers table if not exists
await self._create_providers_table()
# Step 2: Migrate existing models to use providers
await self._migrate_llm_models()
await self._migrate_embedding_models()
# Step 3: Remove deprecated columns
await self._cleanup_columns()
async def _create_providers_table(self):
"""Create model_providers table"""
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("""
CREATE TABLE IF NOT EXISTS model_providers (
uuid VARCHAR(255) PRIMARY KEY,
name VARCHAR(255) NOT NULL,
requester VARCHAR(255) NOT NULL,
base_url VARCHAR(512) NOT NULL,
api_keys JSONB NOT NULL DEFAULT '[]',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
""")
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("""
CREATE TABLE IF NOT EXISTS model_providers (
uuid VARCHAR(255) PRIMARY KEY,
name VARCHAR(255) NOT NULL,
requester VARCHAR(255) NOT NULL,
base_url VARCHAR(512) NOT NULL,
api_keys JSON NOT NULL DEFAULT '[]',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
)
""")
)
async def _migrate_llm_models(self):
"""Migrate LLM models to use providers"""
llm_columns = await self._get_columns('llm_models')
# Add provider_uuid column if not exists
if 'provider_uuid' not in llm_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN provider_uuid VARCHAR(255)')
)
# Only migrate if old columns exist
if 'requester' not in llm_columns:
return
# Get all LLM models with old structure
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM llm_models')
)
models = result.fetchall()
# Create providers and update models
provider_cache = {} # (requester, base_url, api_keys_str) -> provider_uuid
for model in models:
model_uuid, model_name, requester, requester_config, api_keys = model
# Extract base_url from requester_config
base_url = ''
if requester_config:
if isinstance(requester_config, str):
import json
requester_config = json.loads(requester_config)
base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '')
# Parse api_keys if it's a string
if isinstance(api_keys, str):
import json
try:
api_keys = json.loads(api_keys)
except Exception:
api_keys = []
if not api_keys:
api_keys = []
# Create cache key
api_keys_str = str(sorted(api_keys)) if api_keys else '[]'
cache_key = (requester, base_url, api_keys_str)
if cache_key in provider_cache:
provider_uuid = provider_cache[cache_key]
else:
# Create new provider
provider_uuid = str(uuid_lib.uuid4())
provider_name = f'{requester}'
if base_url:
# Extract domain for name
try:
from urllib.parse import urlparse
parsed = urlparse(base_url)
provider_name = parsed.netloc or requester
except Exception:
pass
import json
api_keys_json = json.dumps(api_keys) if api_keys else '[]'
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("""
INSERT INTO model_providers (uuid, name, requester, base_url, api_keys)
VALUES (:uuid, :name, :requester, :base_url, :api_keys)
"""),
{
'uuid': provider_uuid,
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys_json,
},
)
provider_cache[cache_key] = provider_uuid
# Update model with provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('UPDATE llm_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'),
{'provider_uuid': provider_uuid, 'uuid': model_uuid},
)
async def _migrate_embedding_models(self):
"""Migrate embedding models to use providers"""
embedding_columns = await self._get_columns('embedding_models')
# Add provider_uuid column if not exists
if 'provider_uuid' not in embedding_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN provider_uuid VARCHAR(255)')
)
# Only migrate if old columns exist
if 'requester' not in embedding_columns:
return
# Get all embedding models with old structure
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('SELECT uuid, name, requester, requester_config, api_keys FROM embedding_models')
)
models = result.fetchall()
# Get existing providers
provider_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('SELECT uuid, requester, base_url, api_keys FROM model_providers')
)
existing_providers = provider_result.fetchall()
provider_cache = {}
for p in existing_providers:
p_uuid, p_requester, p_base_url, p_api_keys = p
api_keys_str = str(sorted(p_api_keys)) if p_api_keys else '[]'
provider_cache[(p_requester, p_base_url, api_keys_str)] = p_uuid
for model in models:
model_uuid, model_name, requester, requester_config, api_keys = model
base_url = ''
if requester_config:
if isinstance(requester_config, str):
import json
requester_config = json.loads(requester_config)
base_url = requester_config.get('base_url', '') or requester_config.get('base-url', '')
# Parse api_keys if it's a string
if isinstance(api_keys, str):
import json
try:
api_keys = json.loads(api_keys)
except Exception:
api_keys = []
if not api_keys:
api_keys = []
api_keys_str = str(sorted(api_keys)) if api_keys else '[]'
cache_key = (requester, base_url, api_keys_str)
if cache_key in provider_cache:
provider_uuid = provider_cache[cache_key]
else:
provider_uuid = str(uuid_lib.uuid4())
provider_name = f'{requester}'
if base_url:
try:
from urllib.parse import urlparse
parsed = urlparse(base_url)
provider_name = parsed.netloc or requester
except Exception:
pass
import json
api_keys_json = json.dumps(api_keys) if api_keys else '[]'
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("""
INSERT INTO model_providers (uuid, name, requester, base_url, api_keys)
VALUES (:uuid, :name, :requester, :base_url, :api_keys)
"""),
{
'uuid': provider_uuid,
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys_json,
},
)
provider_cache[cache_key] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('UPDATE embedding_models SET provider_uuid = :provider_uuid WHERE uuid = :uuid'),
{'provider_uuid': provider_uuid, 'uuid': model_uuid},
)
async def _cleanup_columns(self):
"""Remove deprecated columns from model tables"""
# SQLite doesn't support DROP COLUMN easily, so we skip for SQLite
if self.ap.persistence_mgr.db.name != 'postgresql':
return
llm_columns = await self._get_columns('llm_models')
deprecated_llm_cols = ['requester', 'requester_config', 'api_keys', 'description', 'source', 'space_model_id']
for col in deprecated_llm_cols:
if col in llm_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN IF EXISTS {col}')
)
embedding_columns = await self._get_columns('embedding_models')
deprecated_embedding_cols = [
'requester',
'requester_config',
'api_keys',
'description',
'source',
'space_model_id',
]
for col in deprecated_embedding_cols:
if col in embedding_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN IF EXISTS {col}')
)
async def _get_columns(self, table_name: str) -> list:
"""Get column names for a table"""
if self.ap.persistence_mgr.db.name == 'postgresql':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}';"
)
)
all_result = result.fetchall()
return [row[0] for row in all_result]
else:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
all_result = result.fetchall()
return [row[1] for row in all_result]
async def downgrade(self):
"""Downgrade"""
pass

View File

@@ -10,11 +10,9 @@ from . import token
from ...entity.persistence import model as persistence_model
from ...entity.errors import provider as provider_errors
FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list'
class ModelManager:
"""模型管理器"""
"""Model manager"""
ap: app.Application
@@ -24,7 +22,7 @@ class ModelManager:
requester_components: list[engine.Component]
requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
def __init__(self, ap: app.Application):
self.ap = ap
@@ -36,7 +34,6 @@ class ModelManager:
async def initialize(self):
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
# forge requester class dict
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
for component in self.requester_components:
requester_dict[component.metadata.name] = component.get_python_component_class()
@@ -46,29 +43,45 @@ class ModelManager:
await self.load_models_from_db()
async def load_models_from_db(self):
"""从数据库加载模型"""
"""Load models from database"""
self.ap.logger.info('Loading models from db...')
self.llm_models = []
self.embedding_models = []
# llm models
# Load all providers first
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
# Load LLM models
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
llm_models = result.all()
for llm_model in llm_models:
try:
await self.load_llm_model(llm_model)
provider = providers.get(llm_model.provider_uuid)
if provider is None:
self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}')
continue
await self.load_llm_model_with_provider(llm_model, provider)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}')
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
# embedding models
# Load embedding models
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
embedding_models = result.all()
for embedding_model in embedding_models:
try:
await self.load_embedding_model(embedding_model)
provider = providers.get(embedding_model.provider_uuid)
if provider is None:
self.ap.logger.warning(
f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}'
)
continue
await self.load_embedding_model_with_provider(embedding_model, provider)
except provider_errors.RequesterNotFoundError as e:
self.ap.logger.warning(
f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}'
@@ -78,27 +91,33 @@ class ModelManager:
async def init_runtime_llm_model(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
model_info: dict,
):
"""初始化运行时 LLM 模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
"""Initialize runtime LLM model from dict (for testing)"""
provider_info = model_info.get('provider', {})
requester_name = provider_info.get('requester', '')
base_url = provider_info.get('base_url', '')
api_keys = provider_info.get('api_keys', [])
if model_info.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(model_info.requester)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
if requester_name not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(requester_name)
requester_cfg = {'base_url': base_url}
requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
# Create a temporary model entity
model_entity = persistence_model.LLMModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
abilities=model_info.get('abilities', []),
extra_args=model_info.get('extra_args', {}),
)
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
model_entity=model_entity,
token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys),
requester=requester_inst,
)
@@ -106,78 +125,165 @@ class ModelManager:
async def init_runtime_embedding_model(
self,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
model_info: dict,
):
"""初始化运行时 Embedding 模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
elif isinstance(model_info, dict):
model_info = persistence_model.EmbeddingModel(**model_info)
"""Initialize runtime embedding model from dict (for testing)"""
provider_info = model_info.get('provider', {})
requester_name = provider_info.get('requester', '')
base_url = provider_info.get('base_url', '')
api_keys = provider_info.get('api_keys', [])
if model_info.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(model_info.requester)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
if requester_name not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(requester_name)
requester_cfg = {'base_url': base_url}
requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
model_entity = persistence_model.EmbeddingModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid='',
extra_args=model_info.get('extra_args', {}),
)
runtime_embedding_model = requester.RuntimeEmbeddingModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
model_entity=model_entity,
token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys),
requester=requester_inst,
)
return runtime_embedding_model
async def load_llm_model(
async def load_llm_model_with_provider(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
model_info: persistence_model.LLMModel | sqlalchemy.Row,
provider: persistence_model.ModelProvider | sqlalchemy.Row,
):
"""加载 LLM 模型"""
runtime_llm_model = await self.init_runtime_llm_model(model_info)
"""Load LLM model with provider info"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
if isinstance(provider, sqlalchemy.Row):
provider = persistence_model.ModelProvider(**provider._mapping)
if provider.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(provider.requester)
requester_cfg = {'base_url': provider.base_url}
requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []),
requester=requester_inst,
)
self.llm_models.append(runtime_llm_model)
async def load_embedding_model(
async def load_embedding_model_with_provider(
self,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row,
provider: persistence_model.ModelProvider | sqlalchemy.Row,
):
"""加载 Embedding 模型"""
runtime_embedding_model = await self.init_runtime_embedding_model(model_info)
"""Load embedding model with provider info"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
if isinstance(provider, sqlalchemy.Row):
provider = persistence_model.ModelProvider(**provider._mapping)
if provider.requester not in self.requester_dict:
raise provider_errors.RequesterNotFoundError(provider.requester)
requester_cfg = {'base_url': provider.base_url}
requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg)
await requester_inst.initialize()
runtime_embedding_model = requester.RuntimeEmbeddingModel(
model_entity=model_info,
token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []),
requester=requester_inst,
)
self.embedding_models.append(runtime_embedding_model)
async def load_llm_model(self, model_info: dict):
"""Load LLM model from dict (with provider info)"""
provider_info = model_info.get('provider', {})
if not provider_info:
raise ValueError('Provider info is required')
model_entity = persistence_model.LLMModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid=model_info.get('provider_uuid', ''),
abilities=model_info.get('abilities', []),
extra_args=model_info.get('extra_args', {}),
)
provider_entity = persistence_model.ModelProvider(
uuid=provider_info.get('uuid', ''),
name=provider_info.get('name', ''),
requester=provider_info.get('requester', ''),
base_url=provider_info.get('base_url', ''),
api_keys=provider_info.get('api_keys', []),
)
await self.load_llm_model_with_provider(model_entity, provider_entity)
async def load_embedding_model(self, model_info: dict):
"""Load embedding model from dict (with provider info)"""
provider_info = model_info.get('provider', {})
if not provider_info:
raise ValueError('Provider info is required')
model_entity = persistence_model.EmbeddingModel(
uuid=model_info.get('uuid', ''),
name=model_info.get('name', ''),
provider_uuid=model_info.get('provider_uuid', ''),
extra_args=model_info.get('extra_args', {}),
)
provider_entity = persistence_model.ModelProvider(
uuid=provider_info.get('uuid', ''),
name=provider_info.get('name', ''),
requester=provider_info.get('requester', ''),
base_url=provider_info.get('base_url', ''),
api_keys=provider_info.get('api_keys', []),
)
await self.load_embedding_model_with_provider(model_entity, provider_entity)
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
"""通过uuid获取 LLM 模型"""
"""Get LLM model by uuid"""
for model in self.llm_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'LLM model {uuid} not found')
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
"""通过uuid获取 Embedding 模型"""
"""Get embedding model by uuid"""
for model in self.embedding_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'Embedding model {uuid} not found')
async def remove_llm_model(self, model_uuid: str):
"""移除 LLM 模型"""
"""Remove LLM model"""
for model in self.llm_models:
if model.model_entity.uuid == model_uuid:
self.llm_models.remove(model)
return
async def remove_embedding_model(self, model_uuid: str):
"""移除 Embedding 模型"""
"""Remove embedding model"""
for model in self.embedding_models:
if model.model_entity.uuid == model_uuid:
self.embedding_models.remove(model)
return
def get_available_requesters_info(self, model_type: str) -> list[dict]:
"""获取所有可用的请求器"""
"""Get all available requesters"""
if model_type != '':
return [
component.to_plain_dict()
@@ -188,14 +294,14 @@ class ModelManager:
return [component.to_plain_dict() for component in self.requester_components]
def get_available_requester_info_by_name(self, name: str) -> dict | None:
"""通过名称获取请求器信息"""
"""Get requester info by name"""
for component in self.requester_components:
if component.metadata.name == name:
return component.to_plain_dict()
return None
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
"""通过名称获取请求器清单"""
"""Get requester manifest by name"""
for component in self.requester_components:
if component.metadata.name == name:
return component

View File

@@ -2,7 +2,7 @@ import langbot
semantic_version = f'v{langbot.__version__}'
required_database_version = 15
required_database_version = 16
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False

View File

@@ -254,118 +254,36 @@ export default function DynamicFormItemComponent({
);
case DynamicFormItemType.LLM_MODEL_SELECTOR:
// Group models by provider
const groupedModels = llmModels.reduce(
(acc, model) => {
const providerName =
model.provider?.name || model.provider?.requester || 'Unknown';
if (!acc[providerName]) acc[providerName] = [];
acc[providerName].push(model);
return acc;
},
{} as Record<string, LLMModel[]>,
);
return (
<Select value={field.value} onValueChange={field.onChange}>
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue placeholder={t('models.selectModel')} />
</SelectTrigger>
<SelectContent>
<SelectGroup>
{llmModels.map((model) => (
<HoverCard key={model.uuid} openDelay={0} closeDelay={0}>
<HoverCardTrigger asChild>
<SelectItem value={model.uuid}>{model.name}</SelectItem>
</HoverCardTrigger>
<HoverCardContent
className="w-80 data-[state=open]:animate-none data-[state=closed]:animate-none"
align="end"
side="right"
sideOffset={10}
>
<div className="space-y-2">
<div className="flex items-center gap-2">
<img
src={httpClient.getProviderRequesterIconURL(
model.requester,
)}
alt="icon"
className="w-8 h-8 rounded-[8%]"
/>
<h4 className="font-medium">{model.name}</h4>
</div>
<p className="text-sm text-muted-foreground">
{model.description}
</p>
{model.requester_config && (
<div className="flex items-center gap-1 text-xs">
<svg
className="w-4 h-4 text-gray-500"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M13.0607 8.11097L14.4749 9.52518C17.2086 12.2589 17.2086 16.691 14.4749 19.4247L14.1214 19.7782C11.3877 22.5119 6.95555 22.5119 4.22188 19.7782C1.48821 17.0446 1.48821 12.6124 4.22188 9.87874L5.6361 11.293C3.68348 13.2456 3.68348 16.4114 5.6361 18.364C7.58872 20.3166 10.7545 20.3166 12.7072 18.364L13.0607 18.0105C15.0133 16.0578 15.0133 12.892 13.0607 10.9394L11.6465 9.52518L13.0607 8.11097ZM19.7782 14.1214L18.364 12.7072C20.3166 10.7545 20.3166 7.58872 18.364 5.6361C16.4114 3.68348 13.2456 3.68348 11.293 5.6361L10.9394 5.98965C8.98678 7.94227 8.98678 11.1081 10.9394 13.0607L12.3536 14.4749L10.9394 15.8891L9.52518 14.4749C6.79151 11.7413 6.79151 7.30911 9.52518 4.57544L9.87874 4.22188C12.6124 1.48821 17.0446 1.48821 19.7782 4.22188C22.5119 6.95555 22.5119 11.3877 19.7782 14.1214Z"></path>
</svg>
<span className="font-semibold">Base URL</span>
{model.requester_config.base_url}
</div>
)}
{model.abilities && model.abilities.length > 0 && (
<div className="flex flex-wrap gap-1">
{model.abilities.map((ability) => (
<div
key={ability}
className="flex items-center gap-1 px-2 py-1 text-xs rounded-full bg-blue-100 text-blue-600"
>
{ability === 'vision' && (
<svg
className="w-3 h-3"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M12 2C17.5228 2 22 6.47715 22 12C22 17.5228 17.5228 22 12 22C6.47715 22 2 17.5228 2 12C2 6.47715 6.47715 2 12 2ZM12 4C7.58172 4 4 7.58172 4 12C4 16.4183 7.58172 20 12 20C16.4183 20 20 16.4183 20 12C20 7.58172 16.4183 4 12 4ZM12 7C14.7614 7 17 9.23858 17 12C17 14.7614 14.7614 17 12 17C9.23858 17 7 14.7614 7 12C7 11.4872 7.07719 10.9925 7.22057 10.5268C7.61175 11.3954 8.48527 12 9.5 12C10.8807 12 12 10.8807 12 9.5C12 8.48527 11.3954 7.61175 10.5269 7.21995C10.9925 7.07719 11.4872 7 12 7Z"></path>
</svg>
)}
{ability === 'func_call' && (
<svg
className="w-3 h-3"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M5.32943 3.27158C6.56252 2.8332 7.9923 3.10749 8.97927 4.09446C10.1002 5.21537 10.3019 6.90741 9.5843 8.23385L20.293 18.9437L18.8788 20.3579L8.16982 9.64875C6.84325 10.3669 5.15069 10.1654 4.02952 9.04421C3.04227 8.05696 2.7681 6.62665 3.20701 5.39332L5.44373 7.63C6.02952 8.21578 6.97927 8.21578 7.56505 7.63C8.15084 7.04421 8.15084 6.09446 7.56505 5.50868L5.32943 3.27158ZM15.6968 5.15512L18.8788 3.38736L20.293 4.80157L18.5252 7.98355L16.7574 8.3371L14.6361 10.4584L13.2219 9.04421L15.3432 6.92289L15.6968 5.15512ZM8.97927 13.2868L10.3935 14.7011L5.09018 20.0044C4.69966 20.3949 4.06649 20.3949 3.67597 20.0044C3.31334 19.6417 3.28744 19.0699 3.59826 18.6774L3.67597 18.5902L8.97927 13.2868Z"></path>
</svg>
)}
<span>
{ability === 'vision'
? t('models.visionAbility')
: ability === 'func_call'
? t('models.functionCallAbility')
: ability}
</span>
</div>
))}
</div>
)}
{model.extra_args &&
Object.keys(model.extra_args).length > 0 && (
<div className="text-xs">
<div className="font-semibold mb-1">
{t('models.extraParameters')}
</div>
<div className="space-y-1">
{Object.entries(
model.extra_args as Record<string, unknown>,
).map(([key, value]) => (
<div
key={key}
className="flex items-center gap-1"
>
<span className="text-gray-500">{key}</span>
<span className="break-all">
{JSON.stringify(value)}
</span>
</div>
))}
</div>
</div>
)}
</div>
</HoverCardContent>
</HoverCard>
{Object.entries(groupedModels).map(([providerName, models]) => (
<SelectGroup key={providerName}>
<SelectLabel>{providerName}</SelectLabel>
{models.map((model) => (
<SelectItem key={model.uuid} value={model.uuid}>
{model.name}
{model.abilities?.includes('vision') && ' 👁'}
{model.abilities?.includes('func_call') && ' 🔧'}
</SelectItem>
))}
</SelectGroup>
))}
</SelectContent>
</Select>
);

View File

@@ -5,19 +5,19 @@ import {
Plus,
MessageSquareText,
Cpu,
Info,
RefreshCw,
ChevronLeft,
Cloud,
HardDrive,
Lock,
ChevronDown,
ChevronRight,
Trash2,
Settings,
Sparkles,
LogIn,
} from 'lucide-react';
import { LLMCardVO } from './component/llm-card/LLMCardVO';
import LLMCard from './component/llm-card/LLMCard';
import LLMForm from './component/llm-form/LLMForm';
import { httpClient } from '@/app/infra/http/HttpClient';
import { LLMModel } from '@/app/infra/entities/api';
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import {
LLMModel,
EmbeddingModel,
ModelProvider,
} from '@/app/infra/entities/api';
import {
Dialog,
DialogContent,
@@ -25,68 +25,67 @@ import {
DialogTitle,
} from '@/components/ui/dialog';
import { Button } from '@/components/ui/button';
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu';
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from '@/components/ui/collapsible';
import { toast } from 'sonner';
import { useTranslation } from 'react-i18next';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { EmbeddingCardVO } from './component/embedding-card/EmbeddingCardVO';
import EmbeddingCard from './component/embedding-card/EmbeddingCard';
import EmbeddingForm from './component/embedding-form/EmbeddingForm';
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from '@/components/ui/card';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import LLMForm from './component/llm-form/LLMForm';
import EmbeddingForm from './component/embedding-form/EmbeddingForm';
import ProviderForm from './component/provider-form/ProviderForm';
interface ModelsDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
}
type ViewMode = 'providers' | 'space' | 'local';
const LANGBOT_MODELS_PROVIDER_NAME = 'LangBot Models';
export default function ModelsDialog({
open,
onOpenChange,
}: ModelsDialogProps) {
const { t } = useTranslation();
const [viewMode, setViewMode] = useState<ViewMode>('providers');
const [activeTab, setActiveTab] = useState<string>('llm');
// User account type
const [providers, setProviders] = useState<ModelProvider[]>([]);
const [accountType, setAccountType] = useState<'local' | 'space'>('local');
const [spaceBalance] = useState<number | null>(null);
// Local models
const [localLLMList, setLocalLLMList] = useState<LLMCardVO[]>([]);
const [localEmbeddingList, setLocalEmbeddingList] = useState<
EmbeddingCardVO[]
>([]);
// Space models
const [spaceLLMList, setSpaceLLMList] = useState<LLMCardVO[]>([]);
const [spaceEmbeddingList, setSpaceEmbeddingList] = useState<
EmbeddingCardVO[]
>([]);
// Sync state
const [isSyncing, setIsSyncing] = useState(false);
// Expanded providers and their models
const [expandedProviders, setExpandedProviders] = useState<Set<string>>(
new Set(),
);
const [providerModels, setProviderModels] = useState<
Record<string, { llm: LLMModel[]; embedding: EmbeddingModel[] }>
>({});
const [loadingProviders, setLoadingProviders] = useState<Set<string>>(
new Set(),
);
// Form modals
const [modalOpen, setModalOpen] = useState<boolean>(false);
const [isEditForm, setIsEditForm] = useState(false);
const [nowSelectedLLM, setNowSelectedLLM] = useState<LLMCardVO | null>(null);
const [embeddingModalOpen, setEmbeddingModalOpen] = useState<boolean>(false);
const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false);
const [nowSelectedEmbedding, setNowSelectedEmbedding] =
useState<EmbeddingCardVO | null>(null);
const [llmFormOpen, setLLMFormOpen] = useState(false);
const [embeddingFormOpen, setEmbeddingFormOpen] = useState(false);
const [providerFormOpen, setProviderFormOpen] = useState(false);
const [editingLLMId, setEditingLLMId] = useState<string | null>(null);
const [editingEmbeddingId, setEditingEmbeddingId] = useState<string | null>(
null,
);
const [editingProviderId, setEditingProviderId] = useState<string | null>(
null,
);
// Requester name lists for display
const [llmRequesterNameList, setLLMRequesterNameList] = useState<
{ label: string; value: string }[]
>([]);
const [embeddingRequesterNameList, setEmbeddingRequesterNameList] = useState<
const [requesterNameList, setRequesterNameList] = useState<
{ label: string; value: string }[]
>([]);
@@ -94,7 +93,7 @@ export default function ModelsDialog({
if (open) {
loadUserInfo();
loadRequesterLists();
loadAllModels();
loadProviders();
}
}, [open]);
@@ -103,7 +102,6 @@ export default function ModelsDialog({
const userInfo = await httpClient.getUserInfo();
setAccountType(userInfo.account_type);
} catch {
// Default to local if user info cannot be fetched
setAccountType('local');
}
}
@@ -111,347 +109,406 @@ export default function ModelsDialog({
async function loadRequesterLists() {
try {
const llmRequesters = await httpClient.getProviderRequesters('llm');
setLLMRequesterNameList(
setRequesterNameList(
llmRequesters.requesters.map((item) => ({
label: extractI18nObject(item.label),
value: item.name,
})),
);
const embeddingRequesters =
await httpClient.getProviderRequesters('text-embedding');
setEmbeddingRequesterNameList(
embeddingRequesters.requesters.map((item) => ({
label: extractI18nObject(item.label),
value: item.name,
})),
);
} catch (err) {
console.error('Failed to load requester lists', err);
}
}
async function loadAllModels() {
await Promise.all([loadLLMModels(), loadEmbeddingModels()]);
}
async function loadLLMModels() {
async function loadProviders() {
try {
const resp = await httpClient.getProviderLLMModels();
const localModels: LLMCardVO[] = [];
const spaceModels: LLMCardVO[] = [];
resp.models.forEach((model: LLMModel & { source?: string }) => {
const cardVO = new LLMCardVO({
id: model.uuid,
iconURL: httpClient.getProviderRequesterIconURL(model.requester),
name: model.name,
providerLabel:
llmRequesterNameList.find((item) => item.value === model.requester)
?.label || model.requester.substring(0, 10),
baseURL: model.requester_config?.base_url,
abilities: model.abilities || [],
});
if (model.source === 'space') {
spaceModels.push(cardVO);
} else {
localModels.push(cardVO);
}
});
setLocalLLMList(localModels);
setSpaceLLMList(spaceModels);
const resp = await httpClient.getModelProviders();
setProviders(resp.providers);
} catch (err) {
console.error('Failed to load LLM models', err);
toast.error(t('models.getModelListError') + (err as Error).message);
console.error('Failed to load providers', err);
toast.error(t('models.loadError'));
}
}
async function loadEmbeddingModels() {
async function loadProviderModels(providerUuid: string) {
if (loadingProviders.has(providerUuid)) return;
setLoadingProviders((prev) => new Set(prev).add(providerUuid));
try {
const resp = await httpClient.getProviderEmbeddingModels();
const localModels: EmbeddingCardVO[] = [];
const spaceModels: EmbeddingCardVO[] = [];
resp.models.forEach(
(model: {
uuid: string;
requester: string;
name: string;
requester_config?: { base_url?: string };
source?: string;
}) => {
const cardVO = new EmbeddingCardVO({
id: model.uuid,
iconURL: httpClient.getProviderRequesterIconURL(model.requester),
name: model.name,
providerLabel:
embeddingRequesterNameList.find(
(item) => item.value === model.requester,
)?.label || model.requester.substring(0, 10),
baseURL: model.requester_config?.base_url || '',
});
if (model.source === 'space') {
spaceModels.push(cardVO);
} else {
localModels.push(cardVO);
}
const [llmResp, embeddingResp] = await Promise.all([
httpClient.getProviderLLMModels(providerUuid),
httpClient.getProviderEmbeddingModels(providerUuid),
]);
setProviderModels((prev) => ({
...prev,
[providerUuid]: {
llm: llmResp.models,
embedding: embeddingResp.models,
},
);
setLocalEmbeddingList(localModels);
setSpaceEmbeddingList(spaceModels);
}));
} catch (err) {
console.error('Failed to load embedding models', err);
toast.error(t('embedding.getModelListError') + (err as Error).message);
}
}
async function handleSyncSpaceModels() {
setIsSyncing(true);
try {
const stats = await httpClient.syncSpaceModels();
toast.success(
t('models.syncSuccess', {
created: stats.created_llm + stats.created_embedding,
updated: stats.updated_llm + stats.updated_embedding,
}),
);
await loadAllModels();
} catch (err) {
toast.error(t('models.syncError') + (err as Error).message);
console.error('Failed to load models', err);
} finally {
setIsSyncing(false);
setLoadingProviders((prev) => {
const next = new Set(prev);
next.delete(providerUuid);
return next;
});
}
}
function selectLLM(cardVO: LLMCardVO, isSpaceModel: boolean) {
if (isSpaceModel) {
// Space models are read-only, just show info
toast.info(t('models.spaceModelReadOnly'));
return;
function toggleProvider(providerUuid: string) {
setExpandedProviders((prev) => {
const next = new Set(prev);
if (next.has(providerUuid)) {
next.delete(providerUuid);
} else {
next.add(providerUuid);
if (!providerModels[providerUuid]) {
loadProviderModels(providerUuid);
}
setIsEditForm(true);
setNowSelectedLLM(cardVO);
setModalOpen(true);
}
return next;
});
}
function handleCreateModelClick() {
setIsEditForm(false);
setNowSelectedLLM(null);
setModalOpen(true);
function handleCreateLLM() {
setEditingLLMId(null);
setLLMFormOpen(true);
}
function selectEmbedding(cardVO: EmbeddingCardVO, isSpaceModel: boolean) {
if (isSpaceModel) {
toast.info(t('models.spaceModelReadOnly'));
return;
}
setIsEditEmbeddingForm(true);
setNowSelectedEmbedding(cardVO);
setEmbeddingModalOpen(true);
function handleCreateEmbedding() {
setEditingEmbeddingId(null);
setEmbeddingFormOpen(true);
}
function handleCreateEmbeddingModelClick() {
setIsEditEmbeddingForm(false);
setNowSelectedEmbedding(null);
setEmbeddingModalOpen(true);
function handleEditLLM(modelId: string) {
setEditingLLMId(modelId);
setLLMFormOpen(true);
}
function renderProviderCards() {
const isSpaceDisabled = accountType === 'local';
function handleEditEmbedding(modelId: string) {
setEditingEmbeddingId(modelId);
setEmbeddingFormOpen(true);
}
function handleEditProvider(providerId: string) {
setEditingProviderId(providerId);
setProviderFormOpen(true);
}
async function handleDeleteProvider(providerId: string) {
try {
await httpClient.deleteModelProvider(providerId);
toast.success(t('models.providerDeleted'));
loadProviders();
} catch (err) {
toast.error(t('models.providerDeleteError') + (err as Error).message);
}
}
async function handleDeleteLLM(modelId: string, providerUuid: string) {
try {
await httpClient.deleteProviderLLMModel(modelId);
toast.success(t('models.deleteSuccess'));
loadProviderModels(providerUuid);
loadProviders(); // Refresh counts
} catch (err) {
toast.error(t('models.deleteError') + (err as Error).message);
}
}
async function handleDeleteEmbedding(modelId: string, providerUuid: string) {
try {
await httpClient.deleteProviderEmbeddingModel(modelId);
toast.success(t('models.deleteSuccess'));
loadProviderModels(providerUuid);
loadProviders();
} catch (err) {
toast.error(t('models.deleteError') + (err as Error).message);
}
}
function handleSpaceLogin() {
window.location.href = '/auth/space';
}
function getRequesterLabel(requester: string) {
return (
<div className="grid grid-cols-1 md:grid-cols-2 gap-6 p-4">
{/* Space Provider Card */}
<Card
className={`cursor-pointer transition-all hover:shadow-lg ${
isSpaceDisabled ? 'opacity-50 cursor-not-allowed' : ''
}`}
onClick={() => !isSpaceDisabled && setViewMode('space')}
>
<CardHeader className="flex flex-row items-center gap-4">
<div className="p-3 bg-blue-100 dark:bg-blue-900 rounded-lg">
<Cloud className="h-8 w-8 text-blue-600 dark:text-blue-400" />
</div>
<div className="flex-1">
<div className="flex items-center gap-2">
<CardTitle>Space</CardTitle>
{isSpaceDisabled && (
<Lock className="h-4 w-4 text-muted-foreground" />
)}
</div>
<CardDescription>
{isSpaceDisabled
? t('models.spaceDisabledForLocalAccount')
: t('models.spaceProviderDescription')}
</CardDescription>
</div>
</CardHeader>
<CardContent>
<div className="flex items-center gap-4 text-sm text-muted-foreground">
<Badge variant="secondary">{spaceLLMList.length} LLM</Badge>
<Badge variant="secondary">
{spaceEmbeddingList.length} Embedding
</Badge>
</div>
</CardContent>
</Card>
{/* Local Provider Card */}
<Card
className="cursor-pointer transition-all hover:shadow-lg"
onClick={() => setViewMode('local')}
>
<CardHeader className="flex flex-row items-center gap-4">
<div className="p-3 bg-green-100 dark:bg-green-900 rounded-lg">
<HardDrive className="h-8 w-8 text-green-600 dark:text-green-400" />
</div>
<div className="flex-1">
<CardTitle>{t('models.localProvider')}</CardTitle>
<CardDescription>
{t('models.localProviderDescription')}
</CardDescription>
</div>
</CardHeader>
<CardContent>
<div className="flex items-center gap-4 text-sm text-muted-foreground">
<Badge variant="secondary">{localLLMList.length} LLM</Badge>
<Badge variant="secondary">
{localEmbeddingList.length} Embedding
</Badge>
</div>
</CardContent>
</Card>
</div>
requesterNameList.find((r) => r.value === requester)?.label || requester
);
}
function renderModelList(
llmList: LLMCardVO[],
embeddingList: EmbeddingCardVO[],
isSpaceModel: boolean = false,
function maskApiKey(key: string): string {
if (!key) return '';
if (key.length <= 8) return '****';
return `${key.slice(0, 4)}...${key.slice(-4)}`;
}
// Separate LangBot Models provider
const langbotProvider = providers.find(
(p) => p.name === LANGBOT_MODELS_PROVIDER_NAME,
);
const otherProviders = providers.filter(
(p) => p.name !== LANGBOT_MODELS_PROVIDER_NAME,
);
function renderProviderCard(
provider: ModelProvider,
isLangBotModels: boolean = false,
) {
const isExpanded = expandedProviders.has(provider.uuid);
const isLoading = loadingProviders.has(provider.uuid);
const models = providerModels[provider.uuid];
const canDelete =
!isLangBotModels &&
(provider.llm_count || 0) === 0 &&
(provider.embedding_count || 0) === 0;
const totalModels =
(provider.llm_count || 0) + (provider.embedding_count || 0);
return (
<Tabs
value={activeTab}
onValueChange={setActiveTab}
className="w-full flex-1 flex flex-col overflow-hidden"
<Card key={provider.uuid} className="mb-2">
<Collapsible
open={isExpanded}
onOpenChange={() => toggleProvider(provider.uuid)}
>
<div className="flex flex-row justify-between items-center mb-2">
<TabsList className="shadow-md py-5 bg-[#f0f0f0] dark:bg-[#2a2a2e]">
<TabsTrigger value="llm" className="px-6 py-4 cursor-pointer">
<MessageSquareText className="h-4 w-4 mr-1.5" />
{t('llm.llmModels')}
</TabsTrigger>
<TabsTrigger value="embedding" className="px-6 py-4 cursor-pointer">
<Cpu className="h-4 w-4 mr-1.5" />
{t('embedding.embeddingModels')}
</TabsTrigger>
</TabsList>
<div className="flex gap-2">
{isSpaceModel ? (
<Button
size="sm"
variant="outline"
onClick={handleSyncSpaceModels}
disabled={isSyncing}
>
<RefreshCw
className={`h-4 w-4 mr-1 ${isSyncing ? 'animate-spin' : ''}`}
<CardHeader className="px-4 pb-2">
<div className="flex items-center justify-between">
<div className="flex items-center gap-2 flex-1">
{isLangBotModels ? (
<div className="p-2 bg-gradient-to-br from-purple-500 to-blue-500 rounded-lg">
<Sparkles className="h-5 w-5 text-white" />
</div>
) : (
<img
src={httpClient.getProviderRequesterIconURL(
provider.requester,
)}
alt={provider.name}
className="h-9 w-9 rounded-lg"
/>
{t('models.syncModels')}
</Button>
)}
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<CardTitle className="text-base">
{isLangBotModels
? provider.name
: getRequesterLabel(provider.requester)}
</CardTitle>
<Badge variant="outline" className="text-xs">
{t('models.modelsCount', { count: totalModels })}
</Badge>
</div>
<p className="text-xs text-muted-foreground truncate">
{isLangBotModels ? (
t('models.langbotModelsDescription')
) : (
<>
{provider.base_url}
{provider.base_url &&
provider.api_keys?.length > 0 &&
' · '}
{provider.api_keys?.length > 0 &&
maskApiKey(provider.api_keys[0])}
</>
)}
</p>
</div>
</div>
<div className="flex items-center gap-1 ml-2">
{isLangBotModels && accountType !== 'space' && (
<Button
variant="outline"
size="sm"
onClick={
activeTab === 'llm'
? handleCreateModelClick
: handleCreateEmbeddingModelClick
}
onClick={(e) => {
e.stopPropagation();
handleSpaceLogin();
}}
>
<Plus className="h-4 w-4 mr-1" />
{activeTab === 'llm'
? t('models.createModel')
: t('embedding.createModel')}
<LogIn className="h-4 w-4 mr-1" />
{t('models.loginWithSpace')}
</Button>
)}
{isLangBotModels && accountType === 'space' && (
<Badge variant="secondary">
{t('models.balance')}: {spaceBalance ?? '--'}
</Badge>
)}
{!isLangBotModels && (
<>
<Button
variant="ghost"
size="icon"
className="h-8 w-8"
onClick={(e) => {
e.stopPropagation();
handleEditProvider(provider.uuid);
}}
>
<Settings className="h-4 w-4" />
</Button>
{canDelete && (
<Button
variant="ghost"
size="icon"
className="h-8 w-8"
onClick={(e) => {
e.stopPropagation();
handleDeleteProvider(provider.uuid);
}}
>
<Trash2 className="h-4 w-4 text-destructive" />
</Button>
)}
</>
)}
</div>
</div>
<div className="mb-3 flex items-center">
<Info className="h-4 w-4 mr-1.5 text-muted-foreground" />
{activeTab === 'llm' ? (
<p className="text-sm text-muted-foreground flex items-center">
{t('llm.description')}
</p>
<CollapsibleTrigger className="flex items-center gap-1 text-xs text-muted-foreground hover:text-foreground cursor-pointer mt-2">
{isExpanded ? (
<ChevronDown className="h-3 w-3" />
) : (
<p className="text-sm text-muted-foreground flex items-center">
{t('embedding.description')}
<ChevronRight className="h-3 w-3" />
)}
<span>
{isExpanded
? t('models.collapseModels')
: t('models.expandModels')}
</span>
</CollapsibleTrigger>
</CardHeader>
<CollapsibleContent>
<CardContent className="px-4">
{isLoading ? (
<p className="text-sm text-muted-foreground text-center py-4">
{t('common.loading')}...
</p>
) : models ? (
<div className="space-y-2">
{models.llm.map((model) => (
<div
key={model.uuid}
className="flex items-center justify-between py-2 px-3 rounded-md border bg-background hover:bg-accent cursor-pointer"
onClick={() => handleEditLLM(model.uuid)}
>
<div className="flex items-center gap-2 flex-wrap">
<span className="text-sm font-medium">
{model.name}
</span>
<Badge variant="secondary" className="text-xs">
{t('models.chat')}
</Badge>
{model.abilities?.includes('vision') && (
<Badge variant="outline" className="text-xs">
👁
</Badge>
)}
{model.abilities?.includes('func_call') && (
<Badge variant="outline" className="text-xs">
🔧
</Badge>
)}
</div>
<Button
variant="ghost"
size="icon"
className="h-7 w-7 flex-shrink-0"
onClick={(e) => {
e.stopPropagation();
handleDeleteLLM(model.uuid, provider.uuid);
}}
>
<Trash2 className="h-4 w-4 text-muted-foreground hover:text-destructive" />
</Button>
</div>
))}
{models.embedding.map((model) => (
<div
key={model.uuid}
className="flex items-center justify-between py-2 px-3 rounded-md border bg-background hover:bg-accent cursor-pointer"
onClick={() => handleEditEmbedding(model.uuid)}
>
<div className="flex items-center gap-2">
<span className="text-sm font-medium">
{model.name}
</span>
<Badge variant="secondary" className="text-xs">
{t('models.embedding')}
</Badge>
</div>
<Button
variant="ghost"
size="icon"
className="h-7 w-7 flex-shrink-0"
onClick={(e) => {
e.stopPropagation();
handleDeleteEmbedding(model.uuid, provider.uuid);
}}
>
<Trash2 className="h-4 w-4 text-muted-foreground hover:text-destructive" />
</Button>
</div>
))}
{models.llm.length === 0 && models.embedding.length === 0 && (
<p className="text-sm text-muted-foreground text-center py-4">
{t('models.noModels')}
</p>
)}
</div>
<TabsContent value="llm" className="flex-1 overflow-auto mt-0">
{llmList.length === 0 ? (
<div className="flex items-center justify-center h-32 text-muted-foreground">
{isSpaceModel
? t('models.noSpaceModels')
: t('models.noLocalModels')}
</div>
) : (
<div className="w-full grid grid-cols-[repeat(auto-fill,minmax(20rem,1fr))] gap-4">
{llmList.map((cardVO) => (
<div
key={cardVO.id}
onClick={() => selectLLM(cardVO, isSpaceModel)}
className={isSpaceModel ? 'cursor-default' : 'cursor-pointer'}
>
<LLMCard cardVO={cardVO} />
</div>
))}
</div>
<p className="text-sm text-muted-foreground text-center py-4">
{t('models.noModels')}
</p>
)}
</TabsContent>
<TabsContent value="embedding" className="flex-1 overflow-auto mt-0">
{embeddingList.length === 0 ? (
<div className="flex items-center justify-center h-32 text-muted-foreground">
{isSpaceModel
? t('models.noSpaceModels')
: t('models.noLocalModels')}
</div>
) : (
<div className="w-full grid grid-cols-[repeat(auto-fill,minmax(20rem,1fr))] gap-4">
{embeddingList.map((cardVO) => (
<div
key={cardVO.id}
onClick={() => selectEmbedding(cardVO, isSpaceModel)}
className={isSpaceModel ? 'cursor-default' : 'cursor-pointer'}
>
<EmbeddingCard cardVO={cardVO} />
</div>
))}
</div>
)}
</TabsContent>
</Tabs>
</CardContent>
</CollapsibleContent>
</Collapsible>
</Card>
);
}
function getDialogTitle() {
switch (viewMode) {
case 'space':
return 'Space ' + t('models.title');
case 'local':
return t('models.localProvider') + ' ' + t('models.title');
default:
return t('models.title');
// Virtual LangBot Models card if not exists
function renderLangBotModelsCard() {
if (langbotProvider) {
return renderProviderCard(langbotProvider, true);
}
return (
<Card className="mb-2">
<CardHeader className="p-3">
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">
<div className="p-2 bg-gradient-to-br from-purple-500 to-blue-500 rounded-lg">
<Sparkles className="h-5 w-5 text-white" />
</div>
<div>
<CardTitle className="text-base">
{LANGBOT_MODELS_PROVIDER_NAME}
</CardTitle>
<p className="text-xs text-muted-foreground">
{t('models.langbotModelsDescription')}
</p>
</div>
</div>
{accountType !== 'space' && (
<Button variant="outline" size="sm" onClick={handleSpaceLogin}>
<LogIn className="h-4 w-4 mr-1" />
{t('models.loginWithSpace')}
</Button>
)}
</div>
</CardHeader>
</Card>
);
}
function handleFormClose() {
setLLMFormOpen(false);
setEmbeddingFormOpen(false);
setProviderFormOpen(false);
loadProviders();
// Refresh expanded providers
expandedProviders.forEach((uuid) => loadProviderModels(uuid));
}
return (
@@ -459,89 +516,101 @@ export default function ModelsDialog({
<Dialog
open={open}
onOpenChange={(newOpen) => {
if (!newOpen && (modalOpen || embeddingModalOpen)) {
if (
!newOpen &&
(llmFormOpen || embeddingFormOpen || providerFormOpen)
)
return;
}
if (!newOpen) {
setViewMode('providers');
}
onOpenChange(newOpen);
}}
>
<DialogContent className="overflow-hidden p-0 !max-w-[80vw] h-[75vh] flex flex-col">
<DialogContent className="overflow-hidden p-0 h-[80vh] flex flex-col">
<DialogHeader className="px-6 pt-6 pb-0">
<div className="flex items-center gap-2">
{viewMode !== 'providers' && (
<Button
variant="ghost"
size="sm"
onClick={() => setViewMode('providers')}
>
<ChevronLeft className="h-4 w-4" />
</Button>
)}
<DialogTitle>{getDialogTitle()}</DialogTitle>
</div>
<DialogTitle>{t('models.title')}</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-auto px-6 pb-6 mt-4">
{viewMode === 'providers' && renderProviderCards()}
{viewMode === 'space' &&
renderModelList(spaceLLMList, spaceEmbeddingList, true)}
{viewMode === 'local' &&
renderModelList(localLLMList, localEmbeddingList, false)}
<div className="flex-1 flex flex-col overflow-hidden px-6 pb-6 mt-4">
{/* Fixed LangBot Models Card */}
<div className="flex-shrink-0">{renderLangBotModelsCard()}</div>
{/* Add Model Button */}
<div className="flex-shrink-0 mb-3 flex justify-end">
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button size="sm" variant="outline">
<Plus className="h-4 w-4 mr-1" />
{t('models.addModel')}
<ChevronDown className="h-4 w-4 ml-1" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem onClick={handleCreateLLM}>
<MessageSquareText className="h-4 w-4 mr-2" />
{t('models.addLLMModel')}
</DropdownMenuItem>
<DropdownMenuItem onClick={handleCreateEmbedding}>
<Cpu className="h-4 w-4 mr-2" />
{t('models.addEmbeddingModel')}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
{/* Scrollable Provider List */}
<div className="flex-1 overflow-auto">
{otherProviders.map((p) => renderProviderCard(p))}
</div>
</div>
</DialogContent>
</Dialog>
<Dialog open={modalOpen} onOpenChange={setModalOpen}>
<DialogContent className="w-[700px] p-6">
<Dialog open={llmFormOpen} onOpenChange={setLLMFormOpen}>
<DialogContent className="w-[700px] max-h-[90vh] overflow-y-auto p-6">
<DialogHeader>
<DialogTitle>
{isEditForm ? t('models.editModel') : t('models.createModel')}
{editingLLMId ? t('models.editModel') : t('models.createModel')}
</DialogTitle>
</DialogHeader>
<LLMForm
editMode={isEditForm}
initLLMId={nowSelectedLLM?.id}
onFormSubmit={() => {
setModalOpen(false);
loadAllModels();
}}
onFormCancel={() => {
setModalOpen(false);
}}
onLLMDeleted={() => {
setModalOpen(false);
loadAllModels();
}}
editMode={!!editingLLMId}
initLLMId={editingLLMId || undefined}
providers={providers}
onFormSubmit={handleFormClose}
onFormCancel={() => setLLMFormOpen(false)}
onLLMDeleted={handleFormClose}
/>
</DialogContent>
</Dialog>
<Dialog open={embeddingModalOpen} onOpenChange={setEmbeddingModalOpen}>
<DialogContent className="w-[700px] p-6">
<Dialog open={embeddingFormOpen} onOpenChange={setEmbeddingFormOpen}>
<DialogContent className="w-[700px] max-h-[90vh] overflow-y-auto p-6">
<DialogHeader>
<DialogTitle>
{isEditEmbeddingForm
{editingEmbeddingId
? t('embedding.editModel')
: t('embedding.createModel')}
</DialogTitle>
</DialogHeader>
<EmbeddingForm
editMode={isEditEmbeddingForm}
initEmbeddingId={nowSelectedEmbedding?.id}
onFormSubmit={() => {
setEmbeddingModalOpen(false);
loadAllModels();
}}
onFormCancel={() => {
setEmbeddingModalOpen(false);
}}
onEmbeddingDeleted={() => {
setEmbeddingModalOpen(false);
loadAllModels();
}}
editMode={!!editingEmbeddingId}
initEmbeddingId={editingEmbeddingId || undefined}
providers={providers}
onFormSubmit={handleFormClose}
onFormCancel={() => setEmbeddingFormOpen(false)}
onEmbeddingDeleted={handleFormClose}
/>
</DialogContent>
</Dialog>
<Dialog open={providerFormOpen} onOpenChange={setProviderFormOpen}>
<DialogContent className="w-[600px] p-6">
<DialogHeader>
<DialogTitle>{t('models.editProvider')}</DialogTitle>
</DialogHeader>
<ProviderForm
providerId={editingProviderId || undefined}
onFormSubmit={handleFormClose}
onFormCancel={() => setProviderFormOpen(false)}
/>
</DialogContent>
</Dialog>

View File

@@ -1,9 +1,6 @@
import { ICreateEmbeddingField } from '../ICreateEmbeddingField';
import { useEffect, useState } from 'react';
import { IChooseRequesterEntity } from '../ChooseRequesterEntity';
import { httpClient } from '@/app/infra/http/HttpClient';
import { EmbeddingModel } from '@/app/infra/entities/api';
import { UUID } from 'uuidjs';
import { ModelProvider } from '@/app/infra/entities/api';
import { zodResolver } from '@hookform/resolvers/zod';
import { useForm } from 'react-hook-form';
@@ -42,59 +39,43 @@ import { toast } from 'sonner';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { AlertCircle } from 'lucide-react';
const getExtraArgSchema = (t: (key: string) => string) =>
z
.object({
key: z.string().min(1, { message: t('models.keyNameRequired') }),
type: z.enum(['string', 'number', 'boolean']),
value: z.string(),
})
.superRefine((data, ctx) => {
if (data.type === 'number' && isNaN(Number(data.value))) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: t('models.mustBeValidNumber'),
path: ['value'],
});
}
if (
data.type === 'boolean' &&
data.value !== 'true' &&
data.value !== 'false'
) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: t('models.mustBeTrueOrFalse'),
path: ['value'],
});
}
});
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
const getFormSchema = (t: (key: string) => string) =>
z.object({
name: z.string().min(1, { message: t('models.modelNameRequired') }),
model_provider: z
.string()
.min(1, { message: t('models.modelProviderRequired') }),
url: z.string().optional(),
api_key: z.string().optional(),
extra_args: z.array(getExtraArgSchema(t)).optional(),
provider_uuid: z.string().optional(),
new_provider_requester: z.string().optional(),
new_provider_url: z.string().optional(),
new_provider_api_key: z.string().optional(),
extra_args: z
.array(
z.object({
key: z.string(),
type: z.enum(['string', 'number', 'boolean']),
value: z.string(),
}),
)
.optional(),
});
interface EmbeddingFormProps {
editMode: boolean;
initEmbeddingId?: string;
providers: ModelProvider[];
onFormSubmit: () => void;
onFormCancel: () => void;
onEmbeddingDeleted: () => void;
}
export default function EmbeddingForm({
editMode,
initEmbeddingId,
providers,
onFormSubmit,
onFormCancel,
onEmbeddingDeleted,
}: {
editMode: boolean;
initEmbeddingId?: string;
onFormSubmit: () => void;
onFormCancel: () => void;
onEmbeddingDeleted: () => void;
}) {
}: EmbeddingFormProps) {
const { t } = useTranslation();
const formSchema = getFormSchema(t);
@@ -102,9 +83,10 @@ export default function EmbeddingForm({
resolver: zodResolver(formSchema),
defaultValues: {
name: '',
model_provider: '',
url: '',
api_key: '',
provider_uuid: '',
new_provider_requester: '',
new_provider_url: '',
new_provider_api_key: '',
extra_args: [],
},
});
@@ -112,54 +94,178 @@ export default function EmbeddingForm({
const [extraArgs, setExtraArgs] = useState<
{ key: string; type: 'string' | 'number' | 'boolean'; value: string }[]
>([]);
const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false);
const [requesterNameList, setRequesterNameList] = useState<
IChooseRequesterEntity[]
>([]);
const [requesterDefaultURLList, setRequesterDefaultURLList] = useState<
string[]
>([]);
const [modelTesting, setModelTesting] = useState(false);
const [testErrorMessage, setTestErrorMessage] = useState<string | null>(null);
const [currentModelProvider, setCurrentModelProvider] = useState('');
const [providerMode, setProviderMode] = useState<'existing' | 'new'>(
'existing',
);
const [requesterList, setRequesterList] = useState<
{ label: string; value: string; category: string; defaultUrl: string }[]
>([]);
useEffect(() => {
initEmbeddingModelFormComponent().then(() => {
loadRequesters();
if (editMode && initEmbeddingId) {
getEmbeddingConfig(initEmbeddingId).then((val) => {
form.setValue('name', val.name);
form.setValue('model_provider', val.model_provider);
setCurrentModelProvider(val.model_provider);
form.setValue('url', val.url);
form.setValue('api_key', val.api_key);
if (val.extra_args) {
const args = val.extra_args.map((arg) => {
const [key, value] = arg.split(':');
let type: 'string' | 'number' | 'boolean' = 'string';
if (!isNaN(Number(value))) {
type = 'number';
} else if (value === 'true' || value === 'false') {
type = 'boolean';
loadModel(initEmbeddingId);
}
return {
key,
type,
value,
};
}, [editMode, initEmbeddingId]);
async function loadRequesters() {
const resp = await httpClient.getProviderRequesters('text-embedding');
setRequesterList(
resp.requesters.map((item) => ({
label: extractI18nObject(item.label),
value: item.name,
category: item.spec.provider_category || 'manufacturer',
defaultUrl:
item.spec.config
.find((c) => c.name === 'base_url')
?.default?.toString() || '',
})),
);
}
async function loadModel(id: string) {
const resp = await httpClient.getProviderEmbeddingModel(id);
const model = resp.model;
form.setValue('name', model.name);
form.setValue('provider_uuid', model.provider_uuid);
if (model.extra_args) {
const args = Object.entries(model.extra_args).map(([key, value]) => {
let type: 'string' | 'number' | 'boolean' = 'string';
if (typeof value === 'number') type = 'number';
else if (typeof value === 'boolean') type = 'boolean';
return { key, type, value: String(value) };
});
setExtraArgs(args);
form.setValue('extra_args', args);
}
});
} else {
form.reset();
setProviderMode('existing');
}
function handleFormSubmit(values: z.infer<typeof formSchema>) {
const extraArgsObj: Record<string, string | number | boolean> = {};
values.extra_args?.forEach((arg) => {
if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value);
else if (arg.type === 'boolean')
extraArgsObj[arg.key] = arg.value === 'true';
else extraArgsObj[arg.key] = arg.value;
});
}, []);
const modelData: Record<string, unknown> = {
name: values.name,
extra_args: extraArgsObj,
};
if (providerMode === 'existing' && values.provider_uuid) {
modelData.provider_uuid = values.provider_uuid;
} else if (providerMode === 'new') {
modelData.provider = {
requester: values.new_provider_requester,
base_url: values.new_provider_url,
api_keys: values.new_provider_api_key
? [values.new_provider_api_key]
: [],
};
}
if (editMode && initEmbeddingId) {
updateModel(initEmbeddingId, modelData);
} else {
createModel(modelData);
}
}
async function createModel(data: Record<string, unknown>) {
try {
await httpClient.createProviderEmbeddingModel(data as never);
toast.success(t('models.createSuccess'));
onFormSubmit();
} catch (err) {
toast.error(t('models.createError') + (err as Error).message);
}
}
async function updateModel(id: string, data: Record<string, unknown>) {
try {
await httpClient.updateProviderEmbeddingModel(id, data as never);
toast.success(t('models.saveSuccess'));
onFormSubmit();
} catch (err) {
toast.error(t('models.saveError') + (err as Error).message);
}
}
async function deleteModel() {
if (!initEmbeddingId) return;
try {
await httpClient.deleteProviderEmbeddingModel(initEmbeddingId);
toast.success(t('models.deleteSuccess'));
onEmbeddingDeleted();
} catch (err) {
toast.error(t('models.deleteError') + (err as Error).message);
}
}
async function testModel() {
setModelTesting(true);
setTestErrorMessage(null);
const values = form.getValues();
const extraArgsObj: Record<string, string | number | boolean> = {};
values.extra_args?.forEach((arg) => {
if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value);
else if (arg.type === 'boolean')
extraArgsObj[arg.key] = arg.value === 'true';
else extraArgsObj[arg.key] = arg.value;
});
let provider: Record<string, unknown>;
if (providerMode === 'existing' && values.provider_uuid) {
const p = providers.find((p) => p.uuid === values.provider_uuid);
provider = {
requester: p?.requester || '',
base_url: p?.base_url || '',
api_keys: p?.api_keys || [],
};
} else {
provider = {
requester: values.new_provider_requester,
base_url: values.new_provider_url,
api_keys: values.new_provider_api_key
? [values.new_provider_api_key]
: [],
};
}
try {
await httpClient.testEmbeddingModel('_', {
uuid: '',
name: values.name,
provider_uuid: '',
provider,
extra_args: extraArgsObj,
} as never);
toast.success(t('models.testSuccess'));
} catch (err) {
setTestErrorMessage((err as Error).message || t('models.testError'));
} finally {
setModelTesting(false);
}
}
const addExtraArg = () => {
setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]);
const newArgs = [
...extraArgs,
{ key: '', type: 'string' as const, value: '' },
];
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
const updateExtraArg = (
@@ -168,10 +274,7 @@ export default function EmbeddingForm({
value: string,
) => {
const newArgs = [...extraArgs];
newArgs[index] = {
...newArgs[index],
[field]: value,
};
newArgs[index] = { ...newArgs[index], [field]: value };
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
@@ -182,167 +285,6 @@ export default function EmbeddingForm({
form.setValue('extra_args', newArgs);
};
async function initEmbeddingModelFormComponent() {
const requesterNameList =
await httpClient.getProviderRequesters('text-embedding');
setRequesterNameList(
requesterNameList.requesters.map((item) => {
return {
label: extractI18nObject(item.label),
value: item.name,
provider_category: item.spec.provider_category || 'manufacturer',
description: extractI18nObject(item.description) || undefined,
};
}),
);
setRequesterDefaultURLList(
requesterNameList.requesters.map((item) => {
const config = item.spec.config;
for (let i = 0; i < config.length; i++) {
if (config[i].name == 'base_url') {
return config[i].default?.toString() || '';
}
}
return '';
}),
);
}
async function getEmbeddingConfig(
id: string,
): Promise<ICreateEmbeddingField> {
const embeddingModel = await httpClient.getProviderEmbeddingModel(id);
const fakeExtraArgs = [];
const extraArgs = embeddingModel.model.extra_args as Record<string, string>;
for (const key in extraArgs) {
fakeExtraArgs.push(`${key}:${extraArgs[key]}`);
}
return {
name: embeddingModel.model.name,
model_provider: embeddingModel.model.requester,
url: embeddingModel.model.requester_config?.base_url,
api_key: embeddingModel.model.api_keys[0],
extra_args: fakeExtraArgs,
};
}
function handleFormSubmit(value: z.infer<typeof formSchema>) {
const extraArgsObj: Record<string, string | number | boolean> = {};
value.extra_args?.forEach(
(arg: { key: string; type: string; value: string }) => {
if (arg.type === 'number') {
extraArgsObj[arg.key] = Number(arg.value);
} else if (arg.type === 'boolean') {
extraArgsObj[arg.key] = arg.value === 'true';
} else {
extraArgsObj[arg.key] = arg.value;
}
},
);
const embeddingModel: EmbeddingModel = {
uuid: editMode ? initEmbeddingId || '' : UUID.generate(),
name: value.name,
description: '',
requester: value.model_provider,
requester_config: {
base_url: value.url || '',
timeout: 120,
},
extra_args: extraArgsObj,
api_keys: value.api_key ? [value.api_key] : [],
};
if (editMode) {
onSaveEdit(embeddingModel).then(() => {
form.reset();
});
} else {
onCreateEmbedding(embeddingModel).then(() => {
form.reset();
});
}
}
async function onCreateEmbedding(embeddingModel: EmbeddingModel) {
try {
await httpClient.createProviderEmbeddingModel(embeddingModel);
onFormSubmit();
toast.success(t('models.createSuccess'));
} catch (err) {
toast.error(t('models.createError') + (err as Error).message);
}
}
async function onSaveEdit(embeddingModel: EmbeddingModel) {
try {
await httpClient.updateProviderEmbeddingModel(
initEmbeddingId || '',
embeddingModel,
);
onFormSubmit();
toast.success(t('models.saveSuccess'));
} catch (err) {
toast.error(t('models.saveError') + (err as Error).message);
}
}
function deleteModel() {
if (initEmbeddingId) {
httpClient
.deleteProviderEmbeddingModel(initEmbeddingId)
.then(() => {
onEmbeddingDeleted();
toast.success(t('models.deleteSuccess'));
})
.catch((err) => {
toast.error(t('models.deleteError') + err.message);
});
}
}
function testEmbeddingModelInForm() {
setModelTesting(true);
setTestErrorMessage(null);
const extraArgsObj: Record<string, string | number | boolean> = {};
form
.getValues('extra_args')
?.forEach((arg: { key: string; type: string; value: string }) => {
if (arg.type === 'number') {
extraArgsObj[arg.key] = Number(arg.value);
} else if (arg.type === 'boolean') {
extraArgsObj[arg.key] = arg.value === 'true';
} else {
extraArgsObj[arg.key] = arg.value;
}
});
const apiKey = form.getValues('api_key');
httpClient
.testEmbeddingModel('_', {
uuid: '',
name: form.getValues('name'),
description: '',
requester: form.getValues('model_provider'),
requester_config: {
base_url: form.getValues('url') ?? '',
timeout: 120,
},
api_keys: apiKey ? [apiKey] : [],
extra_args: extraArgsObj,
})
.then(() => {
toast.success(t('models.testSuccess'));
setTestErrorMessage(null);
})
.catch((err: { message?: string }) => {
setTestErrorMessage(err?.message || t('models.testError'));
})
.finally(() => {
setModelTesting(false);
});
}
return (
<div>
<Dialog
@@ -379,9 +321,8 @@ export default function EmbeddingForm({
<Form {...form}>
<form
onSubmit={form.handleSubmit(handleFormSubmit)}
className="space-y-8"
className="space-y-6"
>
<div className="space-y-4">
<FormField
control={form.control}
name="name"
@@ -392,69 +333,90 @@ export default function EmbeddingForm({
<span className="text-red-500">*</span>
</FormLabel>
<FormControl>
<Input {...field} />
<Input {...field} placeholder="text-embedding-3-small" />
</FormControl>
<FormMessage />
<FormDescription>
{t('models.modelProviderDescription')}
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
<div>
<FormLabel>{t('models.provider')}</FormLabel>
<Tabs
value={providerMode}
onValueChange={(v) => setProviderMode(v as 'existing' | 'new')}
className="mt-2"
>
<TabsList>
<TabsTrigger value="existing">
{t('models.existingProvider')}
</TabsTrigger>
<TabsTrigger value="new">{t('models.newProvider')}</TabsTrigger>
</TabsList>
<TabsContent value="existing" className="mt-3">
<FormField
control={form.control}
name="model_provider"
name="provider_uuid"
render={({ field }) => (
<FormItem>
<FormLabel>
{t('models.modelProvider')}
<span className="text-red-500">*</span>
</FormLabel>
<FormControl>
<Select
onValueChange={(value) => {
field.onChange(value);
setCurrentModelProvider(value);
const index = requesterNameList.findIndex(
(item) => item.value === value,
);
if (index !== -1) {
form.setValue('url', requesterDefaultURLList[index]);
}
onValueChange={field.onChange}
value={field.value}
>
<SelectTrigger className="bg-background">
<SelectValue
placeholder={t('models.selectProvider')}
/>
</SelectTrigger>
<SelectContent>
{providers.map((p) => (
<SelectItem key={p.uuid} value={p.uuid}>
{p.name} ({p.base_url || 'default'})
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
</TabsContent>
<TabsContent value="new" className="mt-3 space-y-4">
<FormField
control={form.control}
name="new_provider_requester"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.requester')}</FormLabel>
<Select
onValueChange={(v) => {
field.onChange(v);
const req = requesterList.find((r) => r.value === v);
if (req)
form.setValue('new_provider_url', req.defaultUrl);
}}
value={field.value}
>
<SelectTrigger className="w-[180px] bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectTrigger className="bg-background">
<SelectValue
placeholder={t('models.selectModelProvider')}
placeholder={t('models.selectRequester')}
/>
</SelectTrigger>
<SelectContent>
<SelectGroup>
<SelectLabel>{t('models.builtin')}</SelectLabel>
{requesterNameList
.filter(
(item) => item.provider_category === 'builtin',
)
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
</SelectItem>
))}
</SelectGroup>
<SelectGroup>
<SelectLabel>
{t('models.modelManufacturer')}
</SelectLabel>
{requesterNameList
.filter(
(item) =>
item.provider_category === 'manufacturer',
)
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
{requesterList
.filter((r) => r.category === 'manufacturer')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
@@ -462,51 +424,36 @@ export default function EmbeddingForm({
<SelectLabel>
{t('models.aggregationPlatform')}
</SelectLabel>
{requesterNameList
.filter((item) => item.provider_category === 'maas')
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
{requesterList
.filter((r) => r.category === 'maas')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
<SelectGroup>
<SelectLabel>{t('models.selfDeployed')}</SelectLabel>
{requesterNameList
.filter(
(item) =>
item.provider_category === 'self-hosted',
)
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
<SelectLabel>
{t('models.selfDeployed')}
</SelectLabel>
{requesterList
.filter((r) => r.category === 'self-hosted')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormControl>
{currentModelProvider &&
requesterNameList.find(
(item) => item.value === currentModelProvider,
)?.description && (
<FormDescription>
{
requesterNameList.find(
(item) => item.value === currentModelProvider,
)?.description
}
</FormDescription>
)}
<FormMessage />
</FormItem>
)}
/>
{!['seekdb-embedding'].includes(currentModelProvider) && (
<FormField
control={form.control}
name="url"
name="new_provider_url"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.requestURL')}</FormLabel>
@@ -517,25 +464,23 @@ export default function EmbeddingForm({
</FormItem>
)}
/>
)}
{!['ollama-chat', 'seekdb-embedding'].includes(
currentModelProvider,
) && (
<FormField
control={form.control}
name="api_key"
name="new_provider_api_key"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.apiKey')}</FormLabel>
<FormControl>
<Input {...field} />
<Input {...field} type="password" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
)}
</TabsContent>
</Tabs>
</div>
<FormItem>
<FormLabel>{t('models.extraParameters')}</FormLabel>
@@ -551,12 +496,10 @@ export default function EmbeddingForm({
/>
<Select
value={arg.type}
onValueChange={(value) =>
updateExtraArg(index, 'type', value)
}
onValueChange={(v) => updateExtraArg(index, 'type', v)}
>
<SelectTrigger className="w-[120px] bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue placeholder={t('models.type')} />
<SelectTrigger className="w-[120px] bg-background">
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="string">
@@ -577,20 +520,14 @@ export default function EmbeddingForm({
updateExtraArg(index, 'value', e.target.value)
}
/>
<button
<Button
type="button"
className="p-2 hover:bg-gray-100 rounded"
variant="ghost"
size="icon"
onClick={() => removeExtraArg(index)}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="w-5 h-5 text-red-500"
>
<path d="M7 4V2H17V4H22V6H20V21C20 21.5523 19.5523 22 19 22H5C4.44772 22 4 21.5523 4 21V6H2V4H7ZM6 6V20H18V6H6ZM9 9H11V17H9V9ZM13 9H15V17H13V9Z"></path>
</svg>
</button>
<span className="text-red-500">×</span>
</Button>
</div>
))}
<Button type="button" variant="outline" onClick={addExtraArg}>
@@ -600,9 +537,8 @@ export default function EmbeddingForm({
<FormDescription>
{t('embedding.extraParametersDescription')}
</FormDescription>
<FormMessage />
</FormItem>
</div>
{testErrorMessage && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
@@ -612,6 +548,7 @@ export default function EmbeddingForm({
</AlertDescription>
</Alert>
)}
<DialogFooter>
{editMode && (
<Button
@@ -622,25 +559,18 @@ export default function EmbeddingForm({
{t('common.delete')}
</Button>
)}
<Button type="submit">
{editMode ? t('common.save') : t('common.submit')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => testEmbeddingModelInForm()}
onClick={testModel}
disabled={modelTesting}
>
{t('common.test')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => onFormCancel()}
>
<Button type="button" variant="outline" onClick={onFormCancel}>
{t('common.cancel')}
</Button>
</DialogFooter>

View File

@@ -1,9 +1,6 @@
import { ICreateLLMField } from '../ICreateLLMField';
import { useEffect, useState } from 'react';
import { IChooseRequesterEntity } from '../ChooseRequesterEntity';
import { httpClient } from '@/app/infra/http/HttpClient';
import { LLMModel } from '@/app/infra/entities/api';
import { UUID } from 'uuidjs';
import { ModelProvider } from '@/app/infra/entities/api';
import { zodResolver } from '@hookform/resolvers/zod';
import { useForm } from 'react-hook-form';
@@ -43,60 +40,45 @@ import { toast } from 'sonner';
import { extractI18nObject } from '@/i18n/I18nProvider';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { AlertCircle } from 'lucide-react';
const getExtraArgSchema = (t: (key: string) => string) =>
z
.object({
key: z.string().min(1, { message: t('models.keyNameRequired') }),
type: z.enum(['string', 'number', 'boolean']),
value: z.string(),
})
.superRefine((data, ctx) => {
if (data.type === 'number' && isNaN(Number(data.value))) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: t('models.mustBeValidNumber'),
path: ['value'],
});
}
if (
data.type === 'boolean' &&
data.value !== 'true' &&
data.value !== 'false'
) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: t('models.mustBeTrueOrFalse'),
path: ['value'],
});
}
});
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
const getFormSchema = (t: (key: string) => string) =>
z.object({
name: z.string().min(1, { message: t('models.modelNameRequired') }),
model_provider: z
.string()
.min(1, { message: t('models.modelProviderRequired') }),
url: z.string().min(1, { message: t('models.requestURLRequired') }),
api_key: z.string().optional(),
provider_uuid: z.string().optional(),
// New provider fields
new_provider_requester: z.string().optional(),
new_provider_url: z.string().optional(),
new_provider_api_key: z.string().optional(),
abilities: z.array(z.string()),
extra_args: z.array(getExtraArgSchema(t)).optional(),
extra_args: z
.array(
z.object({
key: z.string(),
type: z.enum(['string', 'number', 'boolean']),
value: z.string(),
}),
)
.optional(),
});
interface LLMFormProps {
editMode: boolean;
initLLMId?: string;
providers: ModelProvider[];
onFormSubmit: () => void;
onFormCancel: () => void;
onLLMDeleted: () => void;
}
export default function LLMForm({
editMode,
initLLMId,
providers,
onFormSubmit,
onFormCancel,
onLLMDeleted,
}: {
editMode: boolean;
initLLMId?: string;
onFormSubmit: () => void;
onFormCancel: () => void;
onLLMDeleted: () => void;
}) {
}: LLMFormProps) {
const { t } = useTranslation();
const formSchema = getFormSchema(t);
@@ -104,9 +86,10 @@ export default function LLMForm({
resolver: zodResolver(formSchema),
defaultValues: {
name: '',
model_provider: '',
url: '',
api_key: '',
provider_uuid: '',
new_provider_requester: '',
new_provider_url: '',
new_provider_api_key: '',
abilities: [],
extra_args: [],
},
@@ -115,69 +98,186 @@ export default function LLMForm({
const [extraArgs, setExtraArgs] = useState<
{ key: string; type: 'string' | 'number' | 'boolean'; value: string }[]
>([]);
const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false);
const abilityOptions: { label: string; value: string }[] = [
{
label: t('models.visionAbility'),
value: 'vision',
},
{
label: t('models.functionCallAbility'),
value: 'func_call',
},
];
const [requesterNameList, setRequesterNameList] = useState<
IChooseRequesterEntity[]
>([]);
const [requesterDefaultURLList, setRequesterDefaultURLList] = useState<
string[]
>([]);
const [modelTesting, setModelTesting] = useState(false);
const [testErrorMessage, setTestErrorMessage] = useState<string | null>(null);
const [currentModelProvider, setCurrentModelProvider] = useState('');
const [providerMode, setProviderMode] = useState<'existing' | 'new'>(
'existing',
);
const [requesterList, setRequesterList] = useState<
{ label: string; value: string; category: string; defaultUrl: string }[]
>([]);
const abilityOptions = [
{ label: t('models.visionAbility'), value: 'vision' },
{ label: t('models.functionCallAbility'), value: 'func_call' },
];
useEffect(() => {
initLLMModelFormComponent().then(() => {
loadRequesters();
if (editMode && initLLMId) {
getLLMConfig(initLLMId).then((val) => {
form.setValue('name', val.name);
form.setValue('model_provider', val.model_provider);
setCurrentModelProvider(val.model_provider);
form.setValue('url', val.url);
form.setValue('api_key', val.api_key);
form.setValue(
'abilities',
val.abilities as ('vision' | 'func_call')[],
);
// 转换extra_args为新格式
if (val.extra_args) {
const args = val.extra_args.map((arg) => {
const [key, value] = arg.split(':');
let type: 'string' | 'number' | 'boolean' = 'string';
if (!isNaN(Number(value))) {
type = 'number';
} else if (value === 'true' || value === 'false') {
type = 'boolean';
loadModel(initLLMId);
}
return {
key,
type,
value,
};
}, [editMode, initLLMId]);
async function loadRequesters() {
const resp = await httpClient.getProviderRequesters('llm');
setRequesterList(
resp.requesters.map((item) => ({
label: extractI18nObject(item.label),
value: item.name,
category: item.spec.provider_category || 'manufacturer',
defaultUrl:
item.spec.config
.find((c) => c.name === 'base_url')
?.default?.toString() || '',
})),
);
}
async function loadModel(id: string) {
const resp = await httpClient.getProviderLLMModel(id);
const model = resp.model;
form.setValue('name', model.name);
form.setValue('provider_uuid', model.provider_uuid);
form.setValue('abilities', model.abilities || []);
if (model.extra_args) {
const args = Object.entries(model.extra_args).map(([key, value]) => {
let type: 'string' | 'number' | 'boolean' = 'string';
if (typeof value === 'number') type = 'number';
else if (typeof value === 'boolean') type = 'boolean';
return { key, type, value: String(value) };
});
setExtraArgs(args);
form.setValue('extra_args', args);
}
});
} else {
form.reset();
setProviderMode('existing');
}
function handleFormSubmit(values: z.infer<typeof formSchema>) {
const extraArgsObj: Record<string, string | number | boolean> = {};
values.extra_args?.forEach((arg) => {
if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value);
else if (arg.type === 'boolean')
extraArgsObj[arg.key] = arg.value === 'true';
else extraArgsObj[arg.key] = arg.value;
});
}, []);
const modelData: Record<string, unknown> = {
name: values.name,
abilities: values.abilities,
extra_args: extraArgsObj,
};
if (providerMode === 'existing' && values.provider_uuid) {
modelData.provider_uuid = values.provider_uuid;
} else if (providerMode === 'new') {
modelData.provider = {
requester: values.new_provider_requester,
base_url: values.new_provider_url,
api_keys: values.new_provider_api_key
? [values.new_provider_api_key]
: [],
};
}
if (editMode && initLLMId) {
updateModel(initLLMId, modelData);
} else {
createModel(modelData);
}
}
async function createModel(data: Record<string, unknown>) {
try {
await httpClient.createProviderLLMModel(data as never);
toast.success(t('models.createSuccess'));
onFormSubmit();
} catch (err) {
toast.error(t('models.createError') + (err as Error).message);
}
}
async function updateModel(id: string, data: Record<string, unknown>) {
try {
await httpClient.updateProviderLLMModel(id, data as never);
toast.success(t('models.saveSuccess'));
onFormSubmit();
} catch (err) {
toast.error(t('models.saveError') + (err as Error).message);
}
}
async function deleteModel() {
if (!initLLMId) return;
try {
await httpClient.deleteProviderLLMModel(initLLMId);
toast.success(t('models.deleteSuccess'));
onLLMDeleted();
} catch (err) {
toast.error(t('models.deleteError') + (err as Error).message);
}
}
async function testModel() {
setModelTesting(true);
setTestErrorMessage(null);
const values = form.getValues();
const extraArgsObj: Record<string, string | number | boolean> = {};
values.extra_args?.forEach((arg) => {
if (arg.type === 'number') extraArgsObj[arg.key] = Number(arg.value);
else if (arg.type === 'boolean')
extraArgsObj[arg.key] = arg.value === 'true';
else extraArgsObj[arg.key] = arg.value;
});
let provider: Record<string, unknown>;
if (providerMode === 'existing' && values.provider_uuid) {
const p = providers.find((p) => p.uuid === values.provider_uuid);
provider = {
requester: p?.requester || '',
base_url: p?.base_url || '',
api_keys: p?.api_keys || [],
};
} else {
provider = {
requester: values.new_provider_requester,
base_url: values.new_provider_url,
api_keys: values.new_provider_api_key
? [values.new_provider_api_key]
: [],
};
}
try {
await httpClient.testLLMModel('_', {
uuid: '',
name: values.name,
provider_uuid: '',
provider,
abilities: values.abilities,
extra_args: extraArgsObj,
} as never);
toast.success(t('models.testSuccess'));
} catch (err) {
setTestErrorMessage((err as Error).message || t('models.testError'));
} finally {
setModelTesting(false);
}
}
const addExtraArg = () => {
setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]);
const newArgs = [
...extraArgs,
{ key: '', type: 'string' as const, value: '' },
];
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
const updateExtraArg = (
@@ -186,10 +286,7 @@ export default function LLMForm({
value: string,
) => {
const newArgs = [...extraArgs];
newArgs[index] = {
...newArgs[index],
[field]: value,
};
newArgs[index] = { ...newArgs[index], [field]: value };
setExtraArgs(newArgs);
form.setValue('extra_args', newArgs);
};
@@ -200,163 +297,6 @@ export default function LLMForm({
form.setValue('extra_args', newArgs);
};
async function initLLMModelFormComponent() {
const requesterNameList = await httpClient.getProviderRequesters('llm');
setRequesterNameList(
requesterNameList.requesters.map((item) => {
return {
label: extractI18nObject(item.label),
value: item.name,
provider_category: item.spec.provider_category || 'manufacturer',
};
}),
);
setRequesterDefaultURLList(
requesterNameList.requesters.map((item) => {
const config = item.spec.config;
for (let i = 0; i < config.length; i++) {
if (config[i].name == 'base_url') {
return config[i].default?.toString() || '';
}
}
return '';
}),
);
}
async function getLLMConfig(id: string): Promise<ICreateLLMField> {
const llmModel = await httpClient.getProviderLLMModel(id);
const fakeExtraArgs = [];
const extraArgs = llmModel.model.extra_args as Record<string, string>;
for (const key in extraArgs) {
fakeExtraArgs.push(`${key}:${extraArgs[key]}`);
}
return {
name: llmModel.model.name,
model_provider: llmModel.model.requester,
url: llmModel.model.requester_config?.base_url,
api_key: llmModel.model.api_keys[0],
abilities: llmModel.model.abilities || [],
extra_args: fakeExtraArgs,
};
}
function handleFormSubmit(value: z.infer<typeof formSchema>) {
const extraArgsObj: Record<string, string | number | boolean> = {};
value.extra_args?.forEach(
(arg: { key: string; type: string; value: string }) => {
if (arg.type === 'number') {
extraArgsObj[arg.key] = Number(arg.value);
} else if (arg.type === 'boolean') {
extraArgsObj[arg.key] = arg.value === 'true';
} else {
extraArgsObj[arg.key] = arg.value;
}
},
);
const llmModel: LLMModel = {
uuid: editMode ? initLLMId || '' : UUID.generate(),
name: value.name,
description: '',
requester: value.model_provider,
requester_config: {
base_url: value.url,
timeout: 120,
},
extra_args: extraArgsObj,
api_keys: value.api_key ? [value.api_key] : [],
abilities: value.abilities,
};
if (editMode) {
onSaveEdit(llmModel).then(() => {
form.reset();
});
} else {
onCreateLLM(llmModel).then(() => {
form.reset();
});
}
}
async function onCreateLLM(llmModel: LLMModel) {
try {
await httpClient.createProviderLLMModel(llmModel);
onFormSubmit();
toast.success(t('models.createSuccess'));
} catch (err) {
toast.error(t('models.createError') + (err as Error).message);
}
}
async function onSaveEdit(llmModel: LLMModel) {
try {
await httpClient.updateProviderLLMModel(initLLMId || '', llmModel);
onFormSubmit();
toast.success(t('models.saveSuccess'));
} catch (err) {
toast.error(t('models.saveError') + (err as Error).message);
}
}
function deleteModel() {
if (initLLMId) {
httpClient
.deleteProviderLLMModel(initLLMId)
.then(() => {
onLLMDeleted();
toast.success(t('models.deleteSuccess'));
})
.catch((err) => {
toast.error(t('models.deleteError') + err.message);
});
}
}
function testLLMModelInForm() {
setModelTesting(true);
setTestErrorMessage(null);
const extraArgsObj: Record<string, string | number | boolean> = {};
form
.getValues('extra_args')
?.forEach((arg: { key: string; type: string; value: string }) => {
if (arg.type === 'number') {
extraArgsObj[arg.key] = Number(arg.value);
} else if (arg.type === 'boolean') {
extraArgsObj[arg.key] = arg.value === 'true';
} else {
extraArgsObj[arg.key] = arg.value;
}
});
const apiKey = form.getValues('api_key');
httpClient
.testLLMModel('_', {
uuid: '',
name: form.getValues('name'),
description: '',
requester: form.getValues('model_provider'),
requester_config: {
base_url: form.getValues('url'),
timeout: 120,
},
api_keys: apiKey ? [apiKey] : [],
abilities: form.getValues('abilities'),
extra_args: extraArgsObj,
})
.then(() => {
toast.success(t('models.testSuccess'));
setTestErrorMessage(null);
})
.catch((err: { message?: string }) => {
setTestErrorMessage(err?.message || t('models.testError'));
})
.finally(() => {
setModelTesting(false);
});
}
return (
<div>
<Dialog
@@ -393,9 +333,8 @@ export default function LLMForm({
<Form {...form}>
<form
onSubmit={form.handleSubmit(handleFormSubmit)}
className="space-y-8"
className="space-y-6"
>
<div className="space-y-4">
<FormField
control={form.control}
name="name"
@@ -406,42 +345,78 @@ export default function LLMForm({
<span className="text-red-500">*</span>
</FormLabel>
<FormControl>
<Input {...field} />
<Input {...field} placeholder="gpt-4o" />
</FormControl>
<FormMessage />
<FormDescription>
{t('models.modelProviderDescription')}
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
<div>
<FormLabel>{t('models.provider')}</FormLabel>
<Tabs
value={providerMode}
onValueChange={(v) => setProviderMode(v as 'existing' | 'new')}
className="mt-2"
>
<TabsList>
<TabsTrigger value="existing">
{t('models.existingProvider')}
</TabsTrigger>
<TabsTrigger value="new">{t('models.newProvider')}</TabsTrigger>
</TabsList>
<TabsContent value="existing" className="mt-3">
<FormField
control={form.control}
name="model_provider"
name="provider_uuid"
render={({ field }) => (
<FormItem>
<FormLabel>
{t('models.modelProvider')}
<span className="text-red-500">*</span>
</FormLabel>
<FormControl>
<Select
onValueChange={(value) => {
field.onChange(value);
setCurrentModelProvider(value);
const index = requesterNameList.findIndex(
(item) => item.value === value,
);
if (index !== -1) {
form.setValue('url', requesterDefaultURLList[index]);
}
onValueChange={field.onChange}
value={field.value}
>
<SelectTrigger className="bg-background">
<SelectValue
placeholder={t('models.selectProvider')}
/>
</SelectTrigger>
<SelectContent>
{providers.map((p) => (
<SelectItem key={p.uuid} value={p.uuid}>
{p.name} ({p.base_url || 'default'})
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
</TabsContent>
<TabsContent value="new" className="mt-3 space-y-4">
<FormField
control={form.control}
name="new_provider_requester"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.requester')}</FormLabel>
<Select
onValueChange={(v) => {
field.onChange(v);
const req = requesterList.find((r) => r.value === v);
if (req)
form.setValue('new_provider_url', req.defaultUrl);
}}
value={field.value}
>
<SelectTrigger className="w-[180px] bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectTrigger className="bg-background">
<SelectValue
placeholder={t('models.selectModelProvider')}
placeholder={t('models.selectRequester')}
/>
</SelectTrigger>
<SelectContent>
@@ -449,14 +424,11 @@ export default function LLMForm({
<SelectLabel>
{t('models.modelManufacturer')}
</SelectLabel>
{requesterNameList
.filter(
(item) =>
item.provider_category === 'manufacturer',
)
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
{requesterList
.filter((r) => r.category === 'manufacturer')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
@@ -464,30 +436,28 @@ export default function LLMForm({
<SelectLabel>
{t('models.aggregationPlatform')}
</SelectLabel>
{requesterNameList
.filter((item) => item.provider_category === 'maas')
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
{requesterList
.filter((r) => r.category === 'maas')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
<SelectGroup>
<SelectLabel>{t('models.selfDeployed')}</SelectLabel>
{requesterNameList
.filter(
(item) =>
item.provider_category === 'self-hosted',
)
.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
<SelectLabel>
{t('models.selfDeployed')}
</SelectLabel>
{requesterList
.filter((r) => r.category === 'self-hosted')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormControl>
<FormMessage />
</FormItem>
)}
@@ -495,13 +465,10 @@ export default function LLMForm({
<FormField
control={form.control}
name="url"
name="new_provider_url"
render={({ field }) => (
<FormItem>
<FormLabel>
{t('models.requestURL')}
<span className="text-red-500">*</span>
</FormLabel>
<FormLabel>{t('models.requestURL')}</FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
@@ -510,23 +477,22 @@ export default function LLMForm({
)}
/>
{!['lmstudio-chat-completions', 'ollama-chat'].includes(
currentModelProvider,
) && (
<FormField
control={form.control}
name="api_key"
name="new_provider_api_key"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.apiKey')}</FormLabel>
<FormControl>
<Input {...field} />
<Input {...field} type="password" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
)}
</TabsContent>
</Tabs>
</div>
<FormField
control={form.control}
@@ -534,40 +500,32 @@ export default function LLMForm({
render={() => (
<FormItem>
<FormLabel>{t('models.abilities')}</FormLabel>
<div className="mb-0">
<FormDescription>
{t('models.selectModelAbilities')}
</FormDescription>
</div>
{abilityOptions.map((item) => (
<FormField
key={item.value}
control={form.control}
name="abilities"
render={({ field }) => {
return (
<FormItem
key={item.value}
className="flex flex-row items-start space-x-1 space-y-0"
>
render={({ field }) => (
<FormItem className="flex flex-row items-start space-x-2 space-y-0">
<FormControl>
<Checkbox
checked={
Array.isArray(field.value) &&
field.value?.includes(item.value)
}
checked={field.value?.includes(item.value)}
onCheckedChange={(checked) => {
return checked
? field.onChange([
if (checked) {
field.onChange([
...(field.value || []),
item.value,
])
: field.onChange(
]);
} else {
field.onChange(
field.value?.filter(
(value: string) =>
value !== item.value,
(v: string) => v !== item.value,
),
);
}
}}
/>
</FormControl>
@@ -575,11 +533,9 @@ export default function LLMForm({
{item.label}
</FormLabel>
</FormItem>
);
}}
)}
/>
))}
<FormMessage />
</FormItem>
)}
/>
@@ -598,12 +554,10 @@ export default function LLMForm({
/>
<Select
value={arg.type}
onValueChange={(value) =>
updateExtraArg(index, 'type', value)
}
onValueChange={(v) => updateExtraArg(index, 'type', v)}
>
<SelectTrigger className="w-[120px] bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue placeholder={t('models.type')} />
<SelectTrigger className="w-[120px] bg-background">
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="string">
@@ -624,20 +578,14 @@ export default function LLMForm({
updateExtraArg(index, 'value', e.target.value)
}
/>
<button
<Button
type="button"
className="p-2 hover:bg-gray-100 rounded"
variant="ghost"
size="icon"
onClick={() => removeExtraArg(index)}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
className="w-5 h-5 text-red-500"
>
<path d="M7 4V2H17V4H22V6H20V21C20 21.5523 19.5523 22 19 22H5C4.44772 22 4 21.5523 4 21V6H2V4H7ZM6 6V20H18V6H6ZM9 9H11V17H9V9ZM13 9H15V17H13V9Z"></path>
</svg>
</button>
<span className="text-red-500">×</span>
</Button>
</div>
))}
<Button type="button" variant="outline" onClick={addExtraArg}>
@@ -647,9 +595,8 @@ export default function LLMForm({
<FormDescription>
{t('llm.extraParametersDescription')}
</FormDescription>
<FormMessage />
</FormItem>
</div>
{testErrorMessage && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
@@ -659,6 +606,7 @@ export default function LLMForm({
</AlertDescription>
</Alert>
)}
<DialogFooter>
{editMode && (
<Button
@@ -669,25 +617,18 @@ export default function LLMForm({
{t('common.delete')}
</Button>
)}
<Button type="submit">
{editMode ? t('common.save') : t('common.submit')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => testLLMModelInForm()}
onClick={testModel}
disabled={modelTesting}
>
{t('common.test')}
</Button>
<Button
type="button"
variant="outline"
onClick={() => onFormCancel()}
>
<Button type="button" variant="outline" onClick={onFormCancel}>
{t('common.cancel')}
</Button>
</DialogFooter>

View File

@@ -0,0 +1,242 @@
import { useEffect, useState } from 'react';
import { httpClient } from '@/app/infra/http/HttpClient';
import { zodResolver } from '@hookform/resolvers/zod';
import { useForm } from 'react-hook-form';
import { z } from 'zod';
import { useTranslation } from 'react-i18next';
import { Button } from '@/components/ui/button';
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { DialogFooter } from '@/components/ui/dialog';
import { toast } from 'sonner';
import { extractI18nObject } from '@/i18n/I18nProvider';
const getFormSchema = (t: (key: string) => string) =>
z.object({
name: z.string().min(1, { message: t('models.providerNameRequired') }),
requester: z.string().min(1, { message: t('models.requesterRequired') }),
base_url: z.string(),
api_key: z.string().optional(),
});
interface ProviderFormProps {
providerId?: string;
onFormSubmit: () => void;
onFormCancel: () => void;
}
export default function ProviderForm({
providerId,
onFormSubmit,
onFormCancel,
}: ProviderFormProps) {
const { t } = useTranslation();
const formSchema = getFormSchema(t);
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema),
defaultValues: {
name: '',
requester: '',
base_url: '',
api_key: '',
},
});
const [requesterList, setRequesterList] = useState<
{ label: string; value: string; category: string; defaultUrl: string }[]
>([]);
useEffect(() => {
loadRequesters();
if (providerId) {
loadProvider(providerId);
}
}, [providerId]);
async function loadRequesters() {
const resp = await httpClient.getProviderRequesters('llm');
setRequesterList(
resp.requesters.map((item) => ({
label: extractI18nObject(item.label),
value: item.name,
category: item.spec.provider_category || 'manufacturer',
defaultUrl:
item.spec.config
.find((c) => c.name === 'base_url')
?.default?.toString() || '',
})),
);
}
async function loadProvider(id: string) {
const resp = await httpClient.getModelProvider(id);
const provider = resp.provider;
form.setValue('name', provider.name);
form.setValue('requester', provider.requester);
form.setValue('base_url', provider.base_url);
form.setValue('api_key', provider.api_keys?.[0] || '');
}
async function handleFormSubmit(values: z.infer<typeof formSchema>) {
const data = {
name: values.name,
requester: values.requester,
base_url: values.base_url,
api_keys: values.api_key ? [values.api_key] : [],
};
try {
if (providerId) {
await httpClient.updateModelProvider(providerId, data);
toast.success(t('models.providerSaved'));
} else {
await httpClient.createModelProvider(data);
toast.success(t('models.providerCreated'));
}
onFormSubmit();
} catch (err) {
toast.error(t('models.providerSaveError') + (err as Error).message);
}
}
return (
<Form {...form}>
<form
onSubmit={form.handleSubmit(handleFormSubmit)}
className="space-y-4"
>
<FormField
control={form.control}
name="name"
render={({ field }) => (
<FormItem>
<FormLabel>
{t('models.providerName')}
<span className="text-red-500">*</span>
</FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="requester"
render={({ field }) => (
<FormItem>
<FormLabel>
{t('models.requester')}
<span className="text-red-500">*</span>
</FormLabel>
<Select
onValueChange={(v) => {
field.onChange(v);
const req = requesterList.find((r) => r.value === v);
if (req && !form.getValues('base_url')) {
form.setValue('base_url', req.defaultUrl);
}
}}
value={field.value}
>
<SelectTrigger className="bg-background">
<SelectValue placeholder={t('models.selectRequester')} />
</SelectTrigger>
<SelectContent>
<SelectGroup>
<SelectLabel>{t('models.modelManufacturer')}</SelectLabel>
{requesterList
.filter((r) => r.category === 'manufacturer')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
<SelectGroup>
<SelectLabel>{t('models.aggregationPlatform')}</SelectLabel>
{requesterList
.filter((r) => r.category === 'maas')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
<SelectGroup>
<SelectLabel>{t('models.selfDeployed')}</SelectLabel>
{requesterList
.filter((r) => r.category === 'self-hosted')
.map((r) => (
<SelectItem key={r.value} value={r.value}>
{r.label}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="base_url"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.requestURL')}</FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="api_key"
render={({ field }) => (
<FormItem>
<FormLabel>{t('models.apiKey')}</FormLabel>
<FormControl>
<Input {...field} type="password" />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<DialogFooter>
<Button type="submit">{t('common.save')}</Button>
<Button type="button" variant="outline" onClick={onFormCancel}>
{t('common.cancel')}
</Button>
</DialogFooter>
</form>
</Form>
);
}

View File

@@ -19,16 +19,12 @@ import {
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { KnowledgeBase, EmbeddingModel } from '@/app/infra/entities/api';
import { toast } from 'sonner';
import {
HoverCard,
HoverCardContent,
HoverCardTrigger,
} from '@/components/ui/hover-card';
const getFormSchema = (t: (key: string) => string) =>
z.object({
@@ -205,90 +201,35 @@ export default function KBForm({
/>
</SelectTrigger>
<SelectContent className="fixed z-[1000]">
<SelectGroup>
{embeddingModels.map((model) => (
<HoverCard
{(() => {
const grouped = embeddingModels.reduce(
(acc, model) => {
const providerName =
model.provider?.name ||
model.provider?.requester ||
'Unknown';
if (!acc[providerName]) acc[providerName] = [];
acc[providerName].push(model);
return acc;
},
{} as Record<string, EmbeddingModel[]>,
);
return Object.entries(grouped).map(
([providerName, models]) => (
<SelectGroup key={providerName}>
<SelectLabel>{providerName}</SelectLabel>
{models.map((model) => (
<SelectItem
key={model.uuid}
openDelay={0}
closeDelay={0}
value={model.uuid}
>
<HoverCardTrigger asChild>
<SelectItem value={model.uuid}>
{model.name}
</SelectItem>
</HoverCardTrigger>
<HoverCardContent
className="w-80 data-[state=open]:animate-none data-[state=closed]:animate-none"
align="end"
side="right"
sideOffset={10}
>
<div className="space-y-2">
<div className="flex items-center gap-2">
<img
src={httpClient.getProviderRequesterIconURL(
model.requester,
)}
alt="icon"
className="w-8 h-8 rounded-[8%]"
/>
<h4 className="font-medium">
{model.name}
</h4>
</div>
<p className="text-sm text-muted-foreground">
{model.description}
</p>
{model.requester_config && (
<div className="flex items-center gap-1 text-xs">
<svg
className="w-4 h-4 text-gray-500"
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M13.0607 8.11097L14.4749 9.52518C17.2086 12.2589 17.2086 16.691 14.4749 19.4247L14.1214 19.7782C11.3877 22.5119 6.95555 22.5119 4.22188 19.7782C1.48821 17.0446 1.48821 12.6124 4.22188 9.87874L5.6361 11.293C3.68348 13.2456 3.68348 16.4114 5.6361 18.364C7.58872 20.3166 10.7545 20.3166 12.7072 18.364L13.0607 18.0105C15.0133 16.0578 15.0133 12.892 13.0607 10.9394L11.6465 9.52518L13.0607 8.11097ZM19.7782 14.1214L18.364 12.7072C20.3166 10.7545 20.3166 7.58872 18.364 5.6361C16.4114 3.68348 13.2456 3.68348 11.293 5.6361L10.9394 5.98965C8.98678 7.94227 8.98678 11.1081 10.9394 13.0607L12.3536 14.4749L10.9394 15.8891L9.52518 14.4749C6.79151 11.7413 6.79151 7.30911 9.52518 4.57544L9.87874 4.22188C12.6124 1.48821 17.0446 1.48821 19.7782 4.22188C22.5119 6.95555 22.5119 11.3877 19.7782 14.1214Z"></path>
</svg>
<span className="font-semibold">
Base URL
</span>
{model.requester_config.base_url}
</div>
)}
{model.extra_args &&
Object.keys(model.extra_args).length >
0 && (
<div className="text-xs">
<div className="font-semibold mb-1">
{t('models.extraParameters')}
</div>
<div className="space-y-1">
{Object.entries(
model.extra_args as Record<
string,
unknown
>,
).map(([key, value]) => (
<div
key={key}
className="flex items-center gap-1"
>
<span className="text-gray-500">
{key}
</span>
<span className="break-all">
{JSON.stringify(value)}
</span>
</div>
))}
</div>
</div>
)}
</div>
</HoverCardContent>
</HoverCard>
))}
</SelectGroup>
),
);
})()}
</SelectContent>
</Select>
</div>

View File

@@ -41,20 +41,33 @@ export interface ApiRespProviderLLMModel {
model: LLMModel;
}
export interface LLMModel {
name: string;
description: string;
export interface ModelProvider {
uuid: string;
name: string;
requester: string;
requester_config: {
base_url: string;
timeout: number;
};
extra_args?: object;
api_keys: string[];
llm_count?: number;
embedding_count?: number;
created_at?: string;
updated_at?: string;
}
export interface ApiRespModelProviders {
providers: ModelProvider[];
}
export interface ApiRespModelProvider {
provider: ModelProvider;
}
export interface LLMModel {
uuid: string;
name: string;
provider_uuid: string;
provider?: ModelProvider;
abilities?: string[];
// created_at: string;
// updated_at: string;
extra_args?: object;
}
export interface KnowledgeBase {
@@ -76,18 +89,11 @@ export interface ApiRespProviderEmbeddingModel {
}
export interface EmbeddingModel {
name: string;
description: string;
uuid: string;
requester: string;
requester_config: {
base_url: string;
timeout: number;
};
name: string;
provider_uuid: string;
provider?: ModelProvider;
extra_args?: object;
api_keys: string[];
// created_at: string;
// updated_at: string;
}
export interface ApiRespPipelines {

View File

@@ -38,6 +38,9 @@ import {
ExternalKnowledgeBase,
ApiRespExternalKnowledgeBases,
ApiRespExternalKnowledgeBase,
ApiRespModelProviders,
ApiRespModelProvider,
ModelProvider,
} from '@/app/infra/entities/api';
import { Plugin } from '@/app/infra/entities/plugin';
import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest';
@@ -65,7 +68,6 @@ export class BackendClient extends BaseHttpClient {
public getProviderRequesterIconURL(name: string): string {
if (this.instance.defaults.baseURL === '/') {
// 获取用户访问的URL
const url = window.location.href;
const baseURL = url.split('/').slice(0, 3).join('/');
return `${baseURL}/api/v1/provider/requesters/${name}/icon`;
@@ -76,9 +78,38 @@ export class BackendClient extends BaseHttpClient {
);
}
// ============ Model Providers ============
public getModelProviders(): Promise<ApiRespModelProviders> {
return this.get('/api/v1/provider/providers');
}
public getModelProvider(uuid: string): Promise<ApiRespModelProvider> {
return this.get(`/api/v1/provider/providers/${uuid}`);
}
public createModelProvider(
provider: Omit<ModelProvider, 'uuid'>,
): Promise<{ uuid: string }> {
return this.post('/api/v1/provider/providers', provider);
}
public updateModelProvider(
uuid: string,
provider: Partial<ModelProvider>,
): Promise<object> {
return this.put(`/api/v1/provider/providers/${uuid}`, provider);
}
public deleteModelProvider(uuid: string): Promise<object> {
return this.delete(`/api/v1/provider/providers/${uuid}`);
}
// ============ Provider Model LLM ============
public getProviderLLMModels(): Promise<ApiRespProviderLLMModels> {
return this.get('/api/v1/provider/models/llm');
public getProviderLLMModels(
providerUuid?: string,
): Promise<ApiRespProviderLLMModels> {
const params = providerUuid ? { provider_uuid: providerUuid } : {};
return this.get('/api/v1/provider/models/llm', params);
}
public getProviderLLMModel(uuid: string): Promise<ApiRespProviderLLMModel> {
@@ -105,8 +136,11 @@ export class BackendClient extends BaseHttpClient {
}
// ============ Provider Model Embedding ============
public getProviderEmbeddingModels(): Promise<ApiRespProviderEmbeddingModels> {
return this.get('/api/v1/provider/models/embedding');
public getProviderEmbeddingModels(
providerUuid?: string,
): Promise<ApiRespProviderEmbeddingModels> {
const params = providerUuid ? { provider_uuid: providerUuid } : {};
return this.get('/api/v1/provider/models/embedding', params);
}
public getProviderEmbeddingModel(
@@ -716,61 +750,4 @@ export class BackendClient extends BaseHttpClient {
}> {
return this.post('/api/v1/user/space/callback', { code });
}
// ============ Space Models Sync API ============
public syncSpaceModels(spaceUrl?: string): Promise<{
created_llm: number;
updated_llm: number;
created_embedding: number;
updated_embedding: number;
skipped: number;
}> {
return this.post('/api/v1/space/models/sync', { space_url: spaceUrl });
}
public getSpaceModels(): Promise<{
llm_models: Array<{
uuid: string;
name: string;
description: string;
requester: string;
space_model_id: string;
source: string;
}>;
embedding_models: Array<{
uuid: string;
name: string;
description: string;
requester: string;
space_model_id: string;
source: string;
}>;
}> {
return this.get('/api/v1/space/models');
}
public deleteSpaceModels(): Promise<{
deleted_llm: number;
deleted_embedding: number;
}> {
return this.delete('/api/v1/space/models');
}
public getAvailableSpaceModels(spaceUrl?: string): Promise<{
models: Array<{
model_id: string;
display_name: { [key: string]: string };
description: { [key: string]: string };
category: string;
provider: string;
}>;
vendors: Array<{
id: number;
name: string;
}>;
total: number;
}> {
const params = spaceUrl ? { space_url: spaceUrl } : {};
return this.get('/api/v1/space/models/available', params);
}
}

View File

@@ -186,6 +186,36 @@ const enUS = {
spaceModelReadOnly: 'Space models are read-only',
noSpaceModels: 'No Space models. Click Sync to fetch models from Space.',
noLocalModels: 'No local models. Click Create to add a model.',
// New keys for provider-based structure
addModel: 'Add Model',
addLLMModel: 'Add LLM Model',
addEmbeddingModel: 'Add Embedding Model',
provider: 'Provider',
existingProvider: 'Existing Provider',
newProvider: 'New Provider',
selectProvider: 'Select Provider',
requester: 'Requester',
selectRequester: 'Select Requester',
langbotModelsDescription: 'Cloud models powered by LangBot Space',
balance: 'Balance',
loginWithSpace: 'Login with Space',
loginToUseModels: 'Login with Space to use cloud models',
noModels: 'No models configured',
editProvider: 'Edit Provider',
providerName: 'Provider Name',
providerNameRequired: 'Provider name is required',
requesterRequired: 'Requester is required',
providerSaved: 'Provider saved',
providerCreated: 'Provider created',
providerSaveError: 'Failed to save provider: ',
providerDeleted: 'Provider deleted',
providerDeleteError: 'Failed to delete provider: ',
loadError: 'Failed to load data',
chat: 'Chat',
embedding: 'Embedding',
modelsCount: '{{count}} model(s)',
expandModels: 'Expand',
collapseModels: 'Collapse',
},
bots: {
title: 'Bots',

View File

@@ -192,6 +192,35 @@ const jaJP = {
'Space モデルがありません。同期ボタンをクリックして Space からモデルを取得してください。',
noLocalModels:
'ローカルモデルがありません。作成ボタンをクリックしてモデルを追加してください。',
addModel: 'モデルを追加',
addLLMModel: 'LLMモデルを追加',
addEmbeddingModel: '埋め込みモデルを追加',
provider: 'プロバイダー',
existingProvider: '既存のプロバイダー',
newProvider: '新規プロバイダー',
selectProvider: 'プロバイダーを選択',
requester: 'リクエスター',
selectRequester: 'リクエスターを選択',
langbotModelsDescription: 'LangBot Space が提供するクラウドモデル',
balance: '残高',
loginWithSpace: 'Space でログイン',
loginToUseModels: 'Space でログインしてクラウドモデルを使用',
noModels: 'モデルがありません',
editProvider: 'プロバイダーを編集',
providerName: 'プロバイダー名',
providerNameRequired: 'プロバイダー名は必須です',
requesterRequired: 'リクエスターは必須です',
providerSaved: 'プロバイダーを保存しました',
providerCreated: 'プロバイダーを作成しました',
providerSaveError: 'プロバイダーの保存に失敗しました:',
providerDeleted: 'プロバイダーを削除しました',
providerDeleteError: 'プロバイダーの削除に失敗しました:',
loadError: 'データの読み込みに失敗しました',
chat: 'チャット',
embedding: '埋め込み',
modelsCount: '{{count}} 個のモデル',
expandModels: '展開',
collapseModels: '折りたたむ',
},
bots: {
title: 'ボット',

View File

@@ -180,6 +180,36 @@ const zhHans = {
spaceModelReadOnly: 'Space 模型为只读',
noSpaceModels: '暂无 Space 模型。点击同步按钮从 Space 获取模型。',
noLocalModels: '暂无本地模型。点击创建按钮添加模型。',
// 供应商结构新增键
addModel: '添加模型',
addLLMModel: '添加对话模型',
addEmbeddingModel: '添加嵌入模型',
provider: '供应商',
existingProvider: '已有供应商',
newProvider: '新建供应商',
selectProvider: '选择供应商',
requester: '请求器',
selectRequester: '选择请求器',
langbotModelsDescription: 'LangBot Space 提供的云端模型',
balance: '余额',
loginWithSpace: '通过 Space 登录',
loginToUseModels: '通过 Space 登录以使用云端模型',
noModels: '暂无模型',
editProvider: '编辑供应商',
providerName: '供应商名称',
providerNameRequired: '供应商名称不能为空',
requesterRequired: '请求器不能为空',
providerSaved: '供应商已保存',
providerCreated: '供应商已创建',
providerSaveError: '保存供应商失败:',
providerDeleted: '供应商已删除',
providerDeleteError: '删除供应商失败:',
loadError: '加载数据失败',
chat: '对话',
embedding: '嵌入',
modelsCount: '{{count}} 个模型',
expandModels: '展开',
collapseModels: '收起',
},
bots: {
title: '机器人',

View File

@@ -180,6 +180,35 @@ const zhHant = {
spaceModelReadOnly: 'Space 模型為唯讀',
noSpaceModels: '暫無 Space 模型。點擊同步按鈕從 Space 取得模型。',
noLocalModels: '暫無本地模型。點擊建立按鈕新增模型。',
addModel: '新增模型',
addLLMModel: '新增對話模型',
addEmbeddingModel: '新增嵌入模型',
provider: '供應商',
existingProvider: '現有供應商',
newProvider: '新供應商',
selectProvider: '選擇供應商',
requester: '請求器',
selectRequester: '選擇請求器',
langbotModelsDescription: '由 LangBot Space 提供的雲端模型',
balance: '餘額',
loginWithSpace: '使用 Space 登入',
loginToUseModels: '使用 Space 登入以使用雲端模型',
noModels: '暫無模型',
editProvider: '編輯供應商',
providerName: '供應商名稱',
providerNameRequired: '供應商名稱不能為空',
requesterRequired: '請求器不能為空',
providerSaved: '供應商已儲存',
providerCreated: '供應商已建立',
providerSaveError: '儲存供應商失敗:',
providerDeleted: '供應商已刪除',
providerDeleteError: '刪除供應商失敗:',
loadError: '載入資料失敗',
chat: '對話',
embedding: '嵌入',
modelsCount: '{{count}} 個模型',
expandModels: '展開',
collapseModels: '收起',
},
bots: {
title: '機器人',