feat: enhance model creation with UUID preservation option and implement Space model synchronization in ModelManager

This commit is contained in:
Junyan Qin
2025-12-31 22:25:07 +08:00
parent 197258ae91
commit 96e40eaf25
11 changed files with 219 additions and 132 deletions

View File

@@ -64,9 +64,10 @@ class LLMModelsService:
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models]
async def create_llm_model(self, model_data: dict) -> str:
async def create_llm_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
"""Create a new LLM model"""
model_data['uuid'] = str(uuid.uuid4())
if not preserve_uuid:
model_data['uuid'] = str(uuid.uuid4())
# Handle provider creation if needed
if 'provider' in model_data:
@@ -222,9 +223,10 @@ class EmbeddingModelsService:
models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models]
async def create_embedding_model(self, model_data: dict) -> str:
async def create_embedding_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
"""Create a new embedding model"""
model_data['uuid'] = str(uuid.uuid4())
if not preserve_uuid:
model_data['uuid'] = str(uuid.uuid4())
if 'provider' in model_data:
provider_data = model_data.pop('provider')

View File

@@ -8,6 +8,7 @@ import sqlalchemy
from ....core import app
from ....entity.persistence import user
from ....entity.dto.space_model import SpaceModel
class SpaceService:
@@ -170,3 +171,19 @@ class SpaceService:
return credits
except Exception:
return self._credits_cache.get(user_email, (None, 0))[0]
async def get_models(self) -> typing.List[SpaceModel]:
"""Get models from Space"""
space_config = self._get_space_config()
space_url = space_config['url']
async with aiohttp.ClientSession() as session:
async with session.get(f'{space_url}/api/v1/models') as response:
if response.status != 200:
raise ValueError(f'Failed to get models: {await response.text()}')
data = await response.json()
if data.get('code') != 0:
raise ValueError(f'Failed to get models: {data.get("msg")}')
models_data = data.get('data', {}).get('models', [])
return [SpaceModel.model_validate(model_dict) for model_dict in models_data]

View File

@@ -45,68 +45,6 @@ class BuildAppStage(stage.BootingStage):
discover.discover_blueprint('templates/components.yaml')
ap.discover = discover
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
ap.query_pool = pool.QueryPool()
log_cache = logcache.LogCache()
ap.log_cache = log_cache
storage_mgr_inst = storagemgr.StorageMgr(ap)
await storage_mgr_inst.initialize()
ap.storage_mgr = storage_mgr_inst
persistence_mgr_inst = persistencemgr.PersistenceManager(ap)
ap.persistence_mgr = persistence_mgr_inst
await persistence_mgr_inst.initialize()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
# Initialize webhook pusher
webhook_pusher_inst = WebhookPusher(ap)
ap.webhook_pusher = webhook_pusher_inst
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr
rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize()
ap.rag_mgr = rag_mgr_inst
# 初始化向量数据库管理器
vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
await vectordb_mgr_inst.initialize()
ap.vector_db_mgr = vectordb_mgr_inst
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
user_service_inst = user_service.UserService(ap)
ap.user_service = user_service_inst
@@ -143,6 +81,68 @@ class BuildAppStage(stage.BootingStage):
webhook_service_inst = webhook_service.WebhookService(ap)
ap.webhook_service = webhook_service_inst
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
ap.query_pool = pool.QueryPool()
log_cache = logcache.LogCache()
ap.log_cache = log_cache
storage_mgr_inst = storagemgr.StorageMgr(ap)
await storage_mgr_inst.initialize()
ap.storage_mgr = storage_mgr_inst
persistence_mgr_inst = persistencemgr.PersistenceManager(ap)
ap.persistence_mgr = persistence_mgr_inst
await persistence_mgr_inst.initialize()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
ap.model_mgr = llm_model_mgr_inst
await llm_model_mgr_inst.initialize()
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
# Initialize webhook pusher
webhook_pusher_inst = WebhookPusher(ap)
ap.webhook_pusher = webhook_pusher_inst
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr
rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize()
ap.rag_mgr = rag_mgr_inst
# 初始化向量数据库管理器
vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap)
await vectordb_mgr_inst.initialize()
ap.vector_db_mgr = vectordb_mgr_inst
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
await asyncio.sleep(3)
await plugin_connector_inst.initialize()

View File

View File

@@ -0,0 +1,49 @@
# [
# {
# "uuid": "7652ebdb-54dc-412c-a830-e9268ac88471",
# "model_id": "claude-opus-4-5-20251101",
# "display_name": {
# "en_US": "claude-opus-4-5-20251101",
# "zh_Hans": "claude-opus-4-5-20251101"
# },
# "description": {},
# "provider": "anthropic",
# "category": "chat",
# "icon_url": "Claude.Color",
# "tags": {},
# "is_featured": true,
# "featured_order": 999,
# "model_ratio": 2.5,
# "completion_ratio": 5,
# "quota_type": 0,
# "model_price": 0,
# "input_credits": 500,
# "output_credits": 2500,
# "vendor_id": 1,
# "vendor_name": "Anthropic",
# "vendor_icon": "Claude.Color",
# "supported_endpoints": [
# "anthropic",
# "openai"
# ],
# "status": "active",
# "metadata": null,
# "created_at": "2025-12-30T22:23:38.337207+08:00",
# "updated_at": "2025-12-30T22:23:38.337207+08:00"
# }
# ]
import pydantic
class SpaceModel(pydantic.BaseModel):
uuid: str
model_id: str
provider: str
category: str # chat / embedding
llm_abilities: list[str] | None = None
is_featured: bool = False
featured_order: int = 0
status: str
created_at: str | None = None
updated_at: str | None = None

View File

@@ -7,3 +7,11 @@ class RequesterNotFoundError(Exception):
def __str__(self):
return f'Requester {self.requester_name} not found'
class ProviderNotFoundError(Exception):
def __init__(self, provider_name: str):
self.provider_name = provider_name
def __str__(self):
return f'Provider {self.provider_name} not found'

View File

@@ -32,6 +32,7 @@ class LLMModel(Base):
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
@@ -50,6 +51,7 @@ class EmbeddingModel(Base):
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,

View File

@@ -134,7 +134,7 @@ class PersistenceManager:
if result.first() is None:
self.ap.logger.info('Creating space model providers...')
space_chat_completions_model_provider = {
'uuid': str(uuid.uuid4()),
'uuid': '00000000-0000-0000-0000-000000000000',
'name': 'LangBot Models',
'requester': 'space-chat-completions',
'base_url': 'https://api.langbot.cloud/v1',

View File

@@ -1,77 +1,14 @@
import sqlalchemy
from .. import migration
# this is a deprecated migration
@migration.migration_class(15)
class DBMigrateModelSourceTracking(migration.DBMigration):
"""Add source tracking fields to models tables for Space integration"""
async def upgrade(self):
"""Upgrade"""
# Add source column to llm_models table
llm_columns = await self._get_columns('llm_models')
if 'source' not in llm_columns:
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("ALTER TABLE llm_models ADD COLUMN source VARCHAR(32) DEFAULT 'local' NOT NULL")
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("ALTER TABLE llm_models ADD COLUMN source VARCHAR(32) DEFAULT 'local' NOT NULL")
)
if 'space_model_id' not in llm_columns:
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN space_model_id VARCHAR(255)')
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN space_model_id VARCHAR(255)')
)
# Add source column to embedding_models table
embedding_columns = await self._get_columns('embedding_models')
if 'source' not in embedding_columns:
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
"ALTER TABLE embedding_models ADD COLUMN source VARCHAR(32) DEFAULT 'local' NOT NULL"
)
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
"ALTER TABLE embedding_models ADD COLUMN source VARCHAR(32) DEFAULT 'local' NOT NULL"
)
)
if 'space_model_id' not in embedding_columns:
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN space_model_id VARCHAR(255)')
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN space_model_id VARCHAR(255)')
)
async def _get_columns(self, table_name: str) -> list:
"""Get column names for a table"""
if self.ap.persistence_mgr.db.name == 'postgresql':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}';"
)
)
all_result = result.fetchall()
return [row[0] for row in all_result]
else:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name});'))
all_result = result.fetchall()
return [row[1] for row in all_result]
pass
async def downgrade(self):
"""Downgrade"""

View File

@@ -61,6 +61,12 @@ class DBMigrateModelProviderRefactor(migration.DBMigration):
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN provider_uuid VARCHAR(255)')
)
# Add prefered_ranking column if not exists
if 'prefered_ranking' not in llm_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN prefered_ranking INTEGER NOT NULL DEFAULT 0')
)
# Only migrate if old columns exist
if 'requester' not in llm_columns:
return
@@ -152,6 +158,12 @@ class DBMigrateModelProviderRefactor(migration.DBMigration):
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN provider_uuid VARCHAR(255)')
)
# Add prefered_ranking column if not exists
if 'prefered_ranking' not in embedding_columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE embedding_models ADD COLUMN prefered_ranking INTEGER NOT NULL DEFAULT 0')
)
# Only migrate if old columns exist
if 'requester' not in embedding_columns:
return

View File

@@ -41,6 +41,7 @@ class ModelManager:
self.requester_dict = requester_dict
await self.load_models_from_db()
await self.sync_new_models_from_space()
async def load_models_from_db(self):
"""Load models from database"""
@@ -89,6 +90,65 @@ class ModelManager:
except Exception as e:
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
async def sync_new_models_from_space(self):
"""Sync models from Space"""
space_model_provider = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
persistence_model.ModelProvider.requester == 'space-chat-completions'
)
)
result = space_model_provider.first()
if result is None:
raise provider_errors.ProviderNotFoundError('LangBot Models')
space_model_provider = result
# get the latest models from space
space_models = await self.ap.space_service.get_models()
exists_llm_models_uuids = [m['uuid'] for m in await self.ap.llm_model_service.get_llm_models()]
exists_embedding_models_uuids = [
m['uuid'] for m in await self.ap.embedding_models_service.get_embedding_models()
]
for space_model in space_models:
if space_model.category == 'chat':
uuid = space_model.uuid
if uuid in exists_llm_models_uuids:
continue
# model will be automatically loaded
await self.ap.llm_model_service.create_llm_model(
{
'uuid': space_model.uuid,
'name': space_model.model_id,
'provider_uuid': space_model_provider.uuid,
'abilities': space_model.llm_abilities or [],
'extra_args': {},
'prefered_ranking': space_model.featured_order,
},
preserve_uuid=True,
)
elif space_model.category == 'embedding':
uuid = space_model.uuid
if uuid in exists_embedding_models_uuids:
continue
# model will be automatically loaded
await self.ap.embedding_models_service.create_embedding_model(
{
'uuid': space_model.uuid,
'name': space_model.model_id,
'provider_uuid': space_model_provider.uuid,
'extra_args': {},
'prefered_ranking': space_model.featured_order,
},
preserve_uuid=True,
)
async def init_runtime_llm_model(
self,
model_info: dict,