feat: refactor model management to introduce provider structure, enhancing model organization and retrieval

This commit is contained in:
Junyan Qin
2025-12-26 20:27:33 +08:00
parent 455e3db28d
commit 57fcec011d
24 changed files with 2676 additions and 2106 deletions

View File

@@ -9,12 +9,15 @@ class LLMModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
provider_uuid = quart.request.args.get('provider_uuid')
if provider_uuid:
return self.success(
data={'models': await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)}
)
return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.llm_model_service.create_llm_model(json_data)
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
@@ -52,12 +55,19 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
provider_uuid = quart.request.args.get('provider_uuid')
if provider_uuid:
return self.success(
data={
'models': await self.ap.embedding_models_service.get_embedding_models_by_provider(
provider_uuid
)
}
)
return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data)
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)

View File

@@ -0,0 +1,45 @@
import quart
from ... import group
@group.group_class('models/providers', '/api/v1/provider/providers')
class ModelProvidersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _() -> str:
if quart.request.method == 'GET':
providers = await self.ap.provider_service.get_providers()
# Add model counts
for provider in providers:
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
provider['llm_count'] = counts['llm_count']
provider['embedding_count'] = counts['embedding_count']
return self.success(data={'providers': providers})
elif quart.request.method == 'POST':
json_data = await quart.request.json
provider_uuid = await self.ap.provider_service.create_provider(json_data)
return self.success(data={'uuid': provider_uuid})
@self.route(
'/<provider_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY
)
async def _(provider_uuid: str) -> str:
if quart.request.method == 'GET':
provider = await self.ap.provider_service.get_provider(provider_uuid)
if provider is None:
return self.http_status(404, -1, 'provider not found')
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
provider['llm_count'] = counts['llm_count']
provider['embedding_count'] = counts['embedding_count']
return self.success(data={'provider': provider})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
await self.ap.provider_service.update_provider(provider_uuid, json_data)
return self.success()
elif quart.request.method == 'DELETE':
try:
await self.ap.provider_service.delete_provider(provider_uuid)
return self.success()
except ValueError as e:
return self.http_status(400, -1, str(e))

View File

@@ -1,52 +0,0 @@
import quart
from .. import group
DEFAULT_SPACE_URL = 'https://space.langbot.app'
@group.group_class('space', '/api/v1/space')
class SpaceRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/models/sync', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Sync models from Space MaaS to local database"""
json_data = await quart.request.json or {}
space_url = json_data.get('space_url', DEFAULT_SPACE_URL)
try:
stats = await self.ap.space_models_service.sync_models_from_space(user_email, space_url)
return self.success(data=stats)
except ValueError as e:
return self.fail(1, str(e))
except Exception as e:
return self.fail(2, f'Failed to sync models: {str(e)}')
@self.route('/models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Get all synced Space models"""
if quart.request.method == 'GET':
try:
models = await self.ap.space_models_service.get_space_models()
return self.success(data=models)
except Exception as e:
return self.fail(1, f'Failed to get Space models: {str(e)}')
elif quart.request.method == 'DELETE':
try:
stats = await self.ap.space_models_service.delete_space_models()
return self.success(data=stats)
except Exception as e:
return self.fail(1, f'Failed to delete Space models: {str(e)}')
@self.route('/models/available', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
"""Get available models from Space (preview before sync)"""
try:
space_url = quart.request.args.get('space_url', DEFAULT_SPACE_URL)
models_data = await self.ap.space_models_service.fetch_space_models(space_url)
return self.success(data=models_data)
except ValueError as e:
return self.fail(1, str(e))
except Exception as e:
return self.fail(2, f'Failed to fetch available models: {str(e)}')

View File

@@ -11,6 +11,18 @@ from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
def _parse_provider_api_keys(provider_dict: dict) -> dict:
"""Parse api_keys if it's a JSON string"""
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
return provider_dict
class LLMModelsService:
ap: app.Application
@@ -18,29 +30,64 @@ class LLMModelsService:
self.ap = ap
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
"""Get all LLM models with provider info"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all()
masked_columns = []
if not include_secret:
masked_columns = ['api_keys']
# Get all providers for lookup
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
for model in models
]
models_list = []
for model in models:
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
provider = providers.get(model.provider_uuid)
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
provider_dict = _parse_provider_api_keys(provider_dict)
if not include_secret:
provider_dict['api_keys'] = ['***'] * len(provider_dict.get('api_keys', []))
model_dict['provider'] = provider_dict
models_list.append(model_dict)
return models_list
async def get_llm_models_by_provider(self, provider_uuid: str) -> list[dict]:
"""Get LLM models by provider UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.provider_uuid == provider_uuid
)
)
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models]
async def create_llm_model(self, model_data: dict) -> str:
"""Create a new LLM model"""
model_data['uuid'] = str(uuid.uuid4())
# Handle provider creation if needed
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
# Create new provider
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
llm_model = await self.get_llm_model(model_data['uuid'])
await self.ap.model_mgr.load_llm_model(llm_model)
# check if default pipeline has no model bound
# Check if default pipeline has no model bound
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
@@ -56,21 +103,47 @@ class LLMModelsService:
return model_data['uuid']
async def get_llm_model(self, model_uuid: str) -> dict | None:
"""Get a single LLM model with provider info"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
model = result.first()
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
# Get provider
provider_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == model.provider_uuid
)
)
provider = provider_result.first()
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
return model_dict
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
"""Update an existing LLM model"""
if 'uuid' in model_data:
del model_data['uuid']
# Handle provider update if needed
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.uuid == model_uuid)
@@ -78,19 +151,18 @@ class LLMModelsService:
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
llm_model = await self.get_llm_model(model_uuid)
await self.ap.model_mgr.load_llm_model(llm_model)
async def delete_llm_model(self, model_uuid: str) -> None:
"""Delete an LLM model"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
"""Test an LLM model"""
runtime_llm_model: model_requester.RuntimeLLMModel | None = None
if model_uuid != '_':
@@ -98,18 +170,11 @@ class LLMModelsService:
if model.model_entity.uuid == model_uuid:
runtime_llm_model = model
break
if runtime_llm_model is None:
raise Exception('model not found')
else:
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
# Mon Nov 10 2025: Commented for some providers may not support thinking parameter
# # 有些模型厂商默认开启了思考功能,测试容易延迟
# extra_args = model_data.get('extra_args', {})
# if not extra_args or 'thinking' not in extra_args:
# extra_args['thinking'] = {'type': 'disabled'}
extra_args = model_data.get('extra_args', {})
await runtime_llm_model.requester.invoke_llm(
query=None,
@@ -127,42 +192,103 @@ class EmbeddingModelsService:
self.ap = ap
async def get_embedding_models(self) -> list[dict]:
"""Get all embedding models with provider info"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models]
providers_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider)
)
providers = {p.uuid: p for p in providers_result.all()}
models_list = []
for model in models:
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
provider = providers.get(model.provider_uuid)
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
models_list.append(model_dict)
return models_list
async def get_embedding_models_by_provider(self, provider_uuid: str) -> list[dict]:
"""Get embedding models by provider UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
)
)
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models]
async def create_embedding_model(self, model_data: dict) -> str:
"""Create a new embedding model"""
model_data['uuid'] = str(uuid.uuid4())
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
)
embedding_model = await self.get_embedding_model(model_data['uuid'])
await self.ap.model_mgr.load_embedding_model(embedding_model)
return model_data['uuid']
async def get_embedding_model(self, model_uuid: str) -> dict | None:
"""Get a single embedding model with provider info"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.uuid == model_uuid
)
)
model = result.first()
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
provider_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == model.provider_uuid
)
)
provider = provider_result.first()
if provider:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
return model_dict
async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None:
"""Update an existing embedding model"""
if 'uuid' in model_data:
del model_data['uuid']
if 'provider' in model_data:
provider_data = model_data.pop('provider')
if provider_data.get('uuid'):
model_data['provider_uuid'] = provider_data['uuid']
else:
provider_uuid = await self.ap.provider_service.find_or_create_provider(
requester=provider_data.get('requester', ''),
base_url=provider_data.get('base_url', ''),
api_keys=provider_data.get('api_keys', []),
)
model_data['provider_uuid'] = provider_uuid
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.uuid == model_uuid)
@@ -170,21 +296,20 @@ class EmbeddingModelsService:
)
await self.ap.model_mgr.remove_embedding_model(model_uuid)
embedding_model = await self.get_embedding_model(model_uuid)
await self.ap.model_mgr.load_embedding_model(embedding_model)
async def delete_embedding_model(self, model_uuid: str) -> None:
"""Delete an embedding model"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.uuid == model_uuid
)
)
await self.ap.model_mgr.remove_embedding_model(model_uuid)
async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None:
"""Test an embedding model"""
runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None
if model_uuid != '_':
@@ -192,10 +317,8 @@ class EmbeddingModelsService:
if model.model_entity.uuid == model_uuid:
runtime_embedding_model = model
break
if runtime_embedding_model is None:
raise Exception('model not found')
else:
runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data)

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
import uuid
import sqlalchemy
from ....core import app
from ....entity.persistence import model as persistence_model
class ModelProviderService:
"""Service for managing model providers"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_providers(self) -> list[dict]:
"""Get all providers"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
providers = result.all()
providers_list = []
for p in providers:
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, p)
# Parse api_keys if it's a JSON string
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
providers_list.append(provider_dict)
return providers_list
async def get_provider(self, provider_uuid: str) -> dict | None:
"""Get a single provider by UUID"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == provider_uuid
)
)
provider = result.first()
if provider is None:
return None
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
# Parse api_keys if it's a JSON string
if isinstance(provider_dict.get('api_keys'), str):
import json
try:
provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
except Exception:
provider_dict['api_keys'] = []
return provider_dict
async def create_provider(self, provider_data: dict) -> str:
"""Create a new provider"""
provider_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
return provider_data['uuid']
async def update_provider(self, provider_uuid: str, provider_data: dict) -> None:
"""Update an existing provider"""
if 'uuid' in provider_data:
del provider_data['uuid']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == provider_uuid)
.values(**provider_data)
)
# Reload all models using this provider
await self.ap.model_mgr.load_models_from_db()
async def delete_provider(self, provider_uuid: str) -> None:
"""Delete a provider (only if no models reference it)"""
# Check if any models use this provider
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.provider_uuid == provider_uuid
)
)
if llm_result.first() is not None:
raise ValueError('Cannot delete provider: LLM models still reference it')
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.provider_uuid == provider_uuid
)
)
if embedding_result.first() is not None:
raise ValueError('Cannot delete provider: Embedding models still reference it')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.uuid == provider_uuid
)
)
async def get_provider_model_counts(self, provider_uuid: str) -> dict:
"""Get count of models using this provider"""
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(persistence_model.LLMModel)
.where(persistence_model.LLMModel.provider_uuid == provider_uuid)
)
llm_count = llm_result.scalar() or 0
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.provider_uuid == provider_uuid)
)
embedding_count = embedding_result.scalar() or 0
return {'llm_count': llm_count, 'embedding_count': embedding_count}
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
"""Find existing provider or create new one"""
# Try to find existing provider with same config
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.requester == requester,
persistence_model.ModelProvider.base_url == base_url,
)
)
for provider in result.all():
if sorted(provider.api_keys or []) == sorted(api_keys or []):
return provider.uuid
# Create new provider
provider_name = requester
if base_url:
try:
from urllib.parse import urlparse
parsed = urlparse(base_url)
provider_name = parsed.netloc or requester
except Exception:
pass
return await self.create_provider(
{
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys or [],
}
)

View File

@@ -1,247 +0,0 @@
from __future__ import annotations
import typing
import uuid as uuid_lib
import aiohttp
import sqlalchemy
from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import user as persistence_user
DEFAULT_SPACE_URL = 'http://localhost:8383'
# Space's base URL for model API requests (used for requester_config)
SPACE_API_BASE_URL = 'http://localhost:8383'
class SpaceModelsService:
"""Service for syncing models from Space MaaS"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_space_user_info(self, user_email: str) -> persistence_user.User | None:
"""Get Space user info for sync operations"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_user.User).where(persistence_user.User.user == user_email)
)
result_list = result.all()
return result_list[0] if result_list else None
async def fetch_space_models(self, space_url: str = DEFAULT_SPACE_URL) -> typing.Dict:
"""Fetch available models from Space API"""
async with aiohttp.ClientSession() as session:
async with session.get(f'{space_url}/api/v1/models', params={'page_size': 100}) as response:
if response.status != 200:
raise ValueError(f'Failed to fetch models from Space: {await response.text()}')
data = await response.json()
if data.get('code') != 0:
raise ValueError(f'Failed to fetch models from Space: {data.get("msg")}')
return data.get('data', {})
async def sync_models_from_space(
self, user_email: str, space_url: str = DEFAULT_SPACE_URL
) -> typing.Dict[str, typing.Any]:
"""
Sync models from Space to local database.
Returns statistics about the sync operation.
"""
# Get user info for API key
user_obj = await self.get_space_user_info(user_email)
if user_obj is None:
raise ValueError('User not found')
if user_obj.account_type != 'space':
raise ValueError('User is not a Space account')
if not user_obj.space_api_key:
raise ValueError('User does not have a Space API key configured')
# Fetch models from Space
models_data = await self.fetch_space_models(space_url)
space_models = models_data.get('models', [])
# Get existing Space models in local database
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
existing_space_models = {m.space_model_id: m for m in result.all()}
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
existing_space_embedding_models = {m.space_model_id: m for m in result.all()}
stats = {'created_llm': 0, 'updated_llm': 0, 'created_embedding': 0, 'updated_embedding': 0, 'skipped': 0}
for model in space_models:
model_id = model.get('model_id')
category = model.get('category', '')
if not model_id:
stats['skipped'] += 1
continue
if category == 'embedding':
# Handle embedding model
await self._sync_embedding_model(model, user_obj.space_api_key, existing_space_embedding_models, stats)
else:
# Handle LLM model (chat, completion, etc.)
await self._sync_llm_model(model, user_obj.space_api_key, existing_space_models, stats)
return stats
async def _sync_llm_model(
self,
model: typing.Dict,
api_key: str,
existing_models: typing.Dict[str, persistence_model.LLMModel],
stats: typing.Dict,
) -> None:
"""Sync a single LLM model from Space"""
model_id = model.get('model_id')
display_name = model.get('display_name', {})
name = display_name.get('zh_Hans', display_name.get('en_US', model_id))
description_obj = model.get('description', {})
description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else ''
# Infer abilities from model capabilities
abilities = []
supported_endpoints = model.get('supported_endpoints', [])
if 'vision' in str(supported_endpoints).lower() or 'vision' in model_id.lower():
abilities.append('vision')
if 'function' in str(supported_endpoints).lower() or 'tool' in str(supported_endpoints).lower():
abilities.append('function_call')
model_data = {
'name': name,
'description': description[:255] if description else 'Model from Space MaaS',
'requester': 'openai-chat-completions', # Space uses OpenAI-compatible API
'requester_config': {
'base-url': SPACE_API_BASE_URL,
'args': {},
'timeout': 120,
},
'api_keys': [api_key],
'abilities': abilities,
'extra_args': {'model': model_id},
'source': 'space',
'space_model_id': model_id,
}
if model_id in existing_models:
# Update existing model
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.space_model_id == model_id)
.values(**model_data)
)
stats['updated_llm'] += 1
else:
# Create new model
model_data['uuid'] = str(uuid_lib.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
stats['created_llm'] += 1
async def _sync_embedding_model(
self,
model: typing.Dict,
api_key: str,
existing_models: typing.Dict[str, persistence_model.EmbeddingModel],
stats: typing.Dict,
) -> None:
"""Sync a single embedding model from Space"""
model_id = model.get('model_id')
display_name = model.get('display_name', {})
name = display_name.get('zh_Hans', display_name.get('en_US', model_id))
description_obj = model.get('description', {})
description = description_obj.get('zh_Hans', description_obj.get('en_US', '')) if description_obj else ''
model_data = {
'name': name,
'description': description[:255] if description else 'Embedding model from Space MaaS',
'requester': 'openai-embedding', # Space uses OpenAI-compatible API
'requester_config': {
'base-url': SPACE_API_BASE_URL,
'args': {},
'timeout': 120,
},
'api_keys': [api_key],
'extra_args': {'model': model_id},
'source': 'space',
'space_model_id': model_id,
}
if model_id in existing_models:
# Update existing model
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.EmbeddingModel)
.where(persistence_model.EmbeddingModel.space_model_id == model_id)
.values(**model_data)
)
stats['updated_embedding'] += 1
else:
# Create new model
model_data['uuid'] = str(uuid_lib.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
)
stats['created_embedding'] += 1
async def get_space_models(self) -> typing.Dict[str, typing.List]:
"""Get all synced Space models"""
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
return {
'llm_models': [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in llm_result.all()
],
'embedding_models': [
self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m)
for m in embedding_result.all()
],
}
async def delete_space_models(self) -> typing.Dict[str, int]:
"""Delete all synced Space models"""
# Remove from model manager first
llm_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
for model in llm_result.all():
await self.ap.model_mgr.remove_llm_model(model.uuid)
embedding_result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
for model in embedding_result.all():
await self.ap.model_mgr.remove_embedding_model(model.uuid)
# Delete from database
llm_delete = await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.source == 'space')
)
embedding_delete = await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.EmbeddingModel).where(
persistence_model.EmbeddingModel.source == 'space'
)
)
return {'deleted_llm': llm_delete.rowcount, 'deleted_embedding': embedding_delete.rowcount}