feat: database migration

This commit is contained in:
Junyan Qin
2025-04-13 20:50:13 +08:00
parent 854effc43e
commit edc7f81486
6 changed files with 123 additions and 6 deletions
+49 -5
View File
@@ -7,10 +7,12 @@ import typing
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from ..entity.persistence import base, user, model, pipeline, bot, plugin
from . import database, migration
from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata
from ..core import app
from .databases import sqlite
from ..utils import constants
from .migrations import dbm001_migrate_v3_config
class PersistenceManager:
@@ -36,14 +38,56 @@ class PersistenceManager:
await self.create_tables()
async def create_tables(self):
# TODO: 对扩展友好
# 日志
# create tables
async with self.get_db_engine().connect() as conn:
await conn.run_sync(self.meta.create_all)
await conn.commit()
# write initial metadata
for item in metadata.initial_metadata:
# check if the item exists
result = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
)
row = result.first()
if row is None:
await self.execute_async(
sqlalchemy.insert(metadata.Metadata).values(item)
)
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if migration_instance.number > database_version and migration_instance.number <= required_database_version:
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values(
{
'value': str(migration_instance.number)
}
)
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(
self,
*args,