Files
LangBot/pkg/persistence/mgr.py
2025-04-28 13:54:12 +08:00

141 lines
4.9 KiB
Python

from __future__ import annotations
import asyncio
import datetime
import typing
import json
import uuid
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database, migration
from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata
from ..core import app
from .databases import sqlite
from ..utils import constants
from .migrations import dbm001_migrate_v3_config
from ..api.http.service import pipeline as pipeline_service
class PersistenceManager:
    """Persistence module manager.

    Holds the database manager, creates the tables registered on the
    declarative base's metadata, seeds initial rows and applies pending
    database migrations.
    """

    ap: app.Application

    # Database manager
    db: database.BaseDatabaseManager

    meta: sqlalchemy.MetaData

    def __init__(self, ap: app.Application):
        """Store the application handle and the shared table metadata.

        Args:
            ap: The running application instance.
        """
        self.ap = ap
        self.meta = base.Base.metadata

    async def initialize(self):
        """Initialize the database managers, then create and seed the tables.

        Every pre-registered manager is instantiated and initialized in
        order; ``self.db`` ends up bound to the last one.
        """
        self.ap.logger.info("Initializing database...")

        for manager in database.preregistered_managers:
            self.db = manager(self.ap)
            await self.db.initialize()

        await self.create_tables()

    async def create_tables(self):
        """Create all tables, write the initial data and run migrations."""
        # create tables
        async with self.get_db_engine().connect() as conn:
            await conn.run_sync(self.meta.create_all)
            await conn.commit()

        # ======= write initial data =======
        await self._write_initial_metadata()
        await self._write_default_pipeline()
        # =================================

        await self._run_migrations()

    async def _write_initial_metadata(self):
        """Insert every initial metadata item whose key is not stored yet."""
        self.ap.logger.info("Creating initial metadata...")

        for item in metadata.initial_metadata:
            # check if the item exists
            result = await self.execute_async(
                sqlalchemy.select(metadata.Metadata).where(
                    metadata.Metadata.key == item['key']
                )
            )
            if result.first() is None:
                await self.execute_async(
                    sqlalchemy.insert(metadata.Metadata).values(item)
                )

    async def _write_default_pipeline(self):
        """Insert the default pipeline when the pipeline table is empty."""
        result = await self.execute_async(
            sqlalchemy.select(pipeline.LegacyPipeline)
        )
        if result.first() is not None:
            return

        self.ap.logger.info("Creating default pipeline...")

        # Use a context manager so the template file handle is always closed.
        with open(
            'templates/default-pipeline-config.json', 'r', encoding='utf-8'
        ) as config_file:
            pipeline_config = json.load(config_file)

        pipeline_data = {
            'uuid': str(uuid.uuid4()),
            'for_version': self.ap.ver_mgr.get_current_version(),
            'stages': pipeline_service.default_stage_order,
            'is_default': True,
            'name': 'Chat Pipeline',
            'description': '默认对话配置流水线',
            'config': pipeline_config,
        }

        await self.execute_async(
            sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)
        )

    async def _run_migrations(self):
        """Apply the migrations between the stored and required versions.

        After each applied migration the stored ``database_version`` row is
        updated, so an interrupted upgrade resumes from the last completed
        step.
        """
        version_result = await self.execute_async(
            sqlalchemy.select(metadata.Metadata).where(
                metadata.Metadata.key == 'database_version'
            )
        )
        # NOTE(review): index 1 is presumably the `value` column of the
        # metadata row -- confirm against the Metadata model.
        database_version = int(version_result.fetchone()[1])
        required_database_version = constants.required_database_version

        if database_version >= required_database_version:
            return

        # Sorts the registered migration list in place, ascending by number.
        migrations = migration.preregistered_db_migrations
        migrations.sort(key=lambda x: x.number)

        last_migration_number = database_version

        for migration_cls in migrations:
            migration_instance = migration_cls(self.ap)
            number = migration_instance.number

            if not (database_version < number <= required_database_version):
                continue

            await migration_instance.upgrade()
            await self.execute_async(
                sqlalchemy.update(metadata.Metadata)
                .where(metadata.Metadata.key == 'database_version')
                .values({'value': str(number)})
            )
            last_migration_number = number
            self.ap.logger.info(f'Migration {number} completed.')

        self.ap.logger.info(
            f'Successfully upgraded database to version {last_migration_number}.'
        )

    async def execute_async(
        self,
        *args,
        **kwargs
    ) -> sqlalchemy.engine.cursor.CursorResult:
        """Execute a statement on a fresh connection and commit it.

        Args:
            *args: Positional arguments forwarded to ``conn.execute``.
            **kwargs: Keyword arguments forwarded to ``conn.execute``.

        Returns:
            The result object produced by ``conn.execute``.
        """
        async with self.get_db_engine().connect() as conn:
            result = await conn.execute(*args, **kwargs)
            await conn.commit()
            return result

    def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
        """Return the engine exposed by the current database manager."""
        return self.db.get_engine()

    def serialize_model(self, model: typing.Type[base.Base], data: base.Base) -> dict:
        """Convert a model instance into a plain dict keyed by column name.

        Args:
            model: The model class whose ``__table__.columns`` are listed.
            data: The object the column values are read from.

        Returns:
            A dict mapping each column name to the attribute value;
            ``datetime.datetime`` values are converted with ``isoformat()``.
        """
        serialized = {}
        for column in model.__table__.columns:
            # Read the attribute once instead of once per check.
            value = getattr(data, column.name)
            if isinstance(value, datetime.datetime):
                value = value.isoformat()
            serialized[column.name] = value
        return serialized