feat: database migration

This commit is contained in:
Junyan Qin
2025-04-13 20:50:13 +08:00
parent 854effc43e
commit edc7f81486
6 changed files with 123 additions and 6 deletions

View File

@@ -0,0 +1,19 @@
import sqlalchemy
from .base import Base
initial_metadata = [
{
'key': 'database_version',
'value': '0',
},
]
class Metadata(Base):
"""数据库元数据"""
__tablename__ = 'metadata'
key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
value = sqlalchemy.Column(sqlalchemy.String(255))

View File

@@ -7,10 +7,12 @@ import typing
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from ..entity.persistence import base, user, model, pipeline, bot, plugin
from . import database, migration
from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata
from ..core import app
from .databases import sqlite
from ..utils import constants
from .migrations import dbm001_migrate_v3_config
class PersistenceManager:
@@ -36,14 +38,56 @@ class PersistenceManager:
await self.create_tables()
async def create_tables(self):
# TODO: 对扩展友好
# 日志
# create tables
async with self.get_db_engine().connect() as conn:
await conn.run_sync(self.meta.create_all)
await conn.commit()
# write initial metadata
for item in metadata.initial_metadata:
# check if the item exists
result = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
)
row = result.first()
if row is None:
await self.execute_async(
sqlalchemy.insert(metadata.Metadata).values(item)
)
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if migration_instance.number > database_version and migration_instance.number <= required_database_version:
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values(
{
'value': str(migration_instance.number)
}
)
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(
self,
*args,

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
import typing
import abc
from ..core import app
preregistered_db_migrations: list[typing.Type[DBMigration]] = []
def migration_class(number: int):
"""迁移类装饰器"""
def wrapper(cls: typing.Type[DBMigration]) -> typing.Type[DBMigration]:
cls.number = number
preregistered_db_migrations.append(cls)
return cls
return wrapper
class DBMigration(abc.ABC):
"""数据库迁移"""
number: int
"""迁移号"""
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def upgrade(self):
"""升级"""
pass
@abc.abstractmethod
async def downgrade(self):
"""降级"""
pass

View File

View File

@@ -0,0 +1,13 @@
from .. import migration
# TODO fill this
# @migration.migration_class(1)
# class DBMigrationV3(migration.DBMigration):
# """数据库迁移"""
# async def upgrade(self):
# """升级"""
# pass
# async def downgrade(self):
# """降级"""

View File

@@ -1,4 +1,7 @@
semantic_version = "v3.4.11"
semantic_version = "v4.0.0"
required_database_version = 1
"""标记本版本所需要的数据库结构版本,用于判断数据库迁移"""
debug_mode = False