From edc7f81486c26fa0b9ff3249b6e790ffc09e5eb6 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 13 Apr 2025 20:50:13 +0800 Subject: [PATCH] feat: database migration --- pkg/entity/persistence/metadata.py | 19 +++++++ pkg/persistence/mgr.py | 54 +++++++++++++++++-- pkg/persistence/migration.py | 38 +++++++++++++ pkg/persistence/migrations/__init__.py | 0 .../migrations/dbm001_migrate_v3_config.py | 13 +++++ pkg/utils/constants.py | 5 +- 6 files changed, 123 insertions(+), 6 deletions(-) create mode 100644 pkg/entity/persistence/metadata.py create mode 100644 pkg/persistence/migration.py create mode 100644 pkg/persistence/migrations/__init__.py create mode 100644 pkg/persistence/migrations/dbm001_migrate_v3_config.py diff --git a/pkg/entity/persistence/metadata.py b/pkg/entity/persistence/metadata.py new file mode 100644 index 00000000..e1ebaefd --- /dev/null +++ b/pkg/entity/persistence/metadata.py @@ -0,0 +1,19 @@ +import sqlalchemy + +from .base import Base + + +initial_metadata = [ + { + 'key': 'database_version', + 'value': '0', + }, +] + + +class Metadata(Base): + """数据库元数据""" + __tablename__ = 'metadata' + + key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True) + value = sqlalchemy.Column(sqlalchemy.String(255)) diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 56809e6b..dec273db 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -7,10 +7,12 @@ import typing import sqlalchemy.ext.asyncio as sqlalchemy_asyncio import sqlalchemy -from . import database -from ..entity.persistence import base, user, model, pipeline, bot, plugin +from . import database, migration +from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata from ..core import app from .databases import sqlite +from ..utils import constants +from .migrations import dbm001_migrate_v3_config class PersistenceManager: @@ -36,14 +38,56 @@ class PersistenceManager: await self.create_tables() async def create_tables(self): - # TODO: 对扩展友好 - - # 日志 + + # create tables async with self.get_db_engine().connect() as conn: await conn.run_sync(self.meta.create_all) await conn.commit() + # write initial metadata + for item in metadata.initial_metadata: + # check if the item exists + result = await self.execute_async( + sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key']) + ) + row = result.first() + if row is None: + await self.execute_async( + sqlalchemy.insert(metadata.Metadata).values(item) + ) + + # run migrations + database_version = await self.execute_async( + sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') + ) + + database_version = int(database_version.fetchone()[1]) + required_database_version = constants.required_database_version + + if database_version < required_database_version: + migrations = migration.preregistered_db_migrations + migrations.sort(key=lambda x: x.number) + + last_migration_number = database_version + + for migration_cls in migrations: + migration_instance = migration_cls(self.ap) + + if migration_instance.number > database_version and migration_instance.number <= required_database_version: + await migration_instance.upgrade() + await self.execute_async( + sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values( + { + 'value': str(migration_instance.number) + } + ) + ) + last_migration_number = migration_instance.number + self.ap.logger.info(f'Migration {migration_instance.number} completed.') + + self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') + async def execute_async( self, *args, diff --git a/pkg/persistence/migration.py b/pkg/persistence/migration.py new file mode 100644 index 00000000..81a3aac3 --- /dev/null +++ b/pkg/persistence/migration.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app + + +preregistered_db_migrations: list[typing.Type[DBMigration]] = [] + +def migration_class(number: int): + """迁移类装饰器""" + + def wrapper(cls: typing.Type[DBMigration]) -> typing.Type[DBMigration]: + cls.number = number + preregistered_db_migrations.append(cls) + return cls + return wrapper + + +class DBMigration(abc.ABC): + """数据库迁移""" + + number: int + """迁移号""" + + def __init__(self, ap: app.Application): + self.ap = ap + + @abc.abstractmethod + async def upgrade(self): + """升级""" + pass + + @abc.abstractmethod + async def downgrade(self): + """降级""" + pass diff --git a/pkg/persistence/migrations/__init__.py b/pkg/persistence/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/persistence/migrations/dbm001_migrate_v3_config.py b/pkg/persistence/migrations/dbm001_migrate_v3_config.py new file mode 100644 index 00000000..afed5eea --- /dev/null +++ b/pkg/persistence/migrations/dbm001_migrate_v3_config.py @@ -0,0 +1,13 @@ +from .. import migration + +# TODO fill this +# @migration.migration_class(1) +# class DBMigrationV3(migration.DBMigration): +# """数据库迁移""" + +# async def upgrade(self): +# """升级""" +# pass + +# async def downgrade(self): +# """降级""" \ No newline at end of file diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 14b2b74c..f37f2151 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,4 +1,7 @@ -semantic_version = "v3.4.11" +semantic_version = "v4.0.0" + +required_database_version = 1 +"""标记本版本所需要的数据库结构版本,用于判断数据库迁移""" debug_mode = False