From 863b26c3faf88c13c9a0ca51b398b35fd399e4c4 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Mon, 29 Dec 2025 20:42:06 +0800 Subject: [PATCH] refactor: update column drop logic in DBMigrateModelProviderRefactor for PostgreSQL compatibility --- .../dbm016_model_provider_refactor.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py b/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py index 88438409..286967c3 100644 --- a/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py +++ b/src/langbot/pkg/persistence/migrations/dbm016_model_provider_refactor.py @@ -247,9 +247,14 @@ class DBMigrateModelProviderRefactor(migration.DBMigration): deprecated_llm_cols = ['requester', 'requester_config', 'api_keys', 'description', 'source', 'space_model_id'] for col in deprecated_llm_cols: if col in llm_columns: - await self.ap.persistence_mgr.execute_async( - sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN IF EXISTS {col}') - ) + if self.ap.persistence_mgr.db.name == 'postgresql': + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN IF EXISTS {col}') + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE llm_models DROP COLUMN {col}') + ) embedding_columns = await self._get_columns('embedding_models') deprecated_embedding_cols = [ @@ -262,9 +267,14 @@ class DBMigrateModelProviderRefactor(migration.DBMigration): ] for col in deprecated_embedding_cols: if col in embedding_columns: - await self.ap.persistence_mgr.execute_async( - sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN IF EXISTS {col}') - ) + if self.ap.persistence_mgr.db.name == 'postgresql': + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN IF EXISTS {col}') + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(f'ALTER TABLE embedding_models DROP COLUMN {col}') + ) async def _get_columns(self, table_name: str) -> list: """Get column names for a table"""