mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: database migration
This commit is contained in:
19
pkg/entity/persistence/metadata.py
Normal file
19
pkg/entity/persistence/metadata.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
initial_metadata = [
|
||||
{
|
||||
'key': 'database_version',
|
||||
'value': '0',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class Metadata(Base):
|
||||
"""数据库元数据"""
|
||||
__tablename__ = 'metadata'
|
||||
|
||||
key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
value = sqlalchemy.Column(sqlalchemy.String(255))
|
||||
@@ -7,10 +7,12 @@ import typing
|
||||
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
|
||||
from . import database
|
||||
from ..entity.persistence import base, user, model, pipeline, bot, plugin
|
||||
from . import database, migration
|
||||
from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata
|
||||
from ..core import app
|
||||
from .databases import sqlite
|
||||
from ..utils import constants
|
||||
from .migrations import dbm001_migrate_v3_config
|
||||
|
||||
|
||||
class PersistenceManager:
|
||||
@@ -36,14 +38,56 @@ class PersistenceManager:
|
||||
await self.create_tables()
|
||||
|
||||
async def create_tables(self):
|
||||
# TODO: 对扩展友好
|
||||
|
||||
# 日志
|
||||
|
||||
# create tables
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
await conn.run_sync(self.meta.create_all)
|
||||
|
||||
await conn.commit()
|
||||
|
||||
# write initial metadata
|
||||
for item in metadata.initial_metadata:
|
||||
# check if the item exists
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
await self.execute_async(
|
||||
sqlalchemy.insert(metadata.Metadata).values(item)
|
||||
)
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
)
|
||||
|
||||
database_version = int(database_version.fetchone()[1])
|
||||
required_database_version = constants.required_database_version
|
||||
|
||||
if database_version < required_database_version:
|
||||
migrations = migration.preregistered_db_migrations
|
||||
migrations.sort(key=lambda x: x.number)
|
||||
|
||||
last_migration_number = database_version
|
||||
|
||||
for migration_cls in migrations:
|
||||
migration_instance = migration_cls(self.ap)
|
||||
|
||||
if migration_instance.number > database_version and migration_instance.number <= required_database_version:
|
||||
await migration_instance.upgrade()
|
||||
await self.execute_async(
|
||||
sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values(
|
||||
{
|
||||
'value': str(migration_instance.number)
|
||||
}
|
||||
)
|
||||
)
|
||||
last_migration_number = migration_instance.number
|
||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
async def execute_async(
|
||||
self,
|
||||
*args,
|
||||
|
||||
38
pkg/persistence/migration.py
Normal file
38
pkg/persistence/migration.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ..core import app
|
||||
|
||||
|
||||
preregistered_db_migrations: list[typing.Type[DBMigration]] = []
|
||||
|
||||
def migration_class(number: int):
|
||||
"""迁移类装饰器"""
|
||||
|
||||
def wrapper(cls: typing.Type[DBMigration]) -> typing.Type[DBMigration]:
|
||||
cls.number = number
|
||||
preregistered_db_migrations.append(cls)
|
||||
return cls
|
||||
return wrapper
|
||||
|
||||
|
||||
class DBMigration(abc.ABC):
|
||||
"""数据库迁移"""
|
||||
|
||||
number: int
|
||||
"""迁移号"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
0
pkg/persistence/migrations/__init__.py
Normal file
0
pkg/persistence/migrations/__init__.py
Normal file
13
pkg/persistence/migrations/dbm001_migrate_v3_config.py
Normal file
13
pkg/persistence/migrations/dbm001_migrate_v3_config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .. import migration
|
||||
|
||||
# TODO fill this
|
||||
# @migration.migration_class(1)
|
||||
# class DBMigrationV3(migration.DBMigration):
|
||||
# """数据库迁移"""
|
||||
|
||||
# async def upgrade(self):
|
||||
# """升级"""
|
||||
# pass
|
||||
|
||||
# async def downgrade(self):
|
||||
# """降级"""
|
||||
@@ -1,4 +1,7 @@
|
||||
semantic_version = "v3.4.11"
|
||||
semantic_version = "v4.0.0"
|
||||
|
||||
required_database_version = 1
|
||||
"""标记本版本所需要的数据库结构版本,用于判断数据库迁移"""
|
||||
|
||||
debug_mode = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user