from __future__ import annotations import datetime import typing import sqlalchemy.ext.asyncio as sqlalchemy_asyncio import sqlalchemy from . import database, migration from ..entity.persistence import base, metadata, model as persistence_model from ..entity import persistence from ..core import app from ..utils import constants, importutil from . import databases, migrations importutil.import_modules_in_pkg(databases) importutil.import_modules_in_pkg(migrations) importutil.import_modules_in_pkg(persistence) class PersistenceManager: """Persistence module manager""" ap: app.Application db: database.BaseDatabaseManager """Database manager""" meta: sqlalchemy.MetaData def __init__(self, ap: app.Application): self.ap = ap self.meta = base.Base.metadata async def initialize(self): database_type = self.ap.instance_config.data.get('database', {}).get('use', 'sqlite') self.ap.logger.info(f'Initializing database type: {database_type}...') for manager in database.preregistered_managers: if manager.name == database_type: self.db = manager(self.ap) await self.db.initialize() break await self.create_tables() # run migrations database_version = await self.execute_async( sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') ) database_version = int(database_version.fetchone()[1]) required_database_version = constants.required_database_version if database_version < required_database_version: migrations = migration.preregistered_db_migrations migrations.sort(key=lambda x: x.number) last_migration_number = database_version for migration_cls in migrations: migration_instance = migration_cls(self.ap) if ( migration_instance.number > database_version and migration_instance.number <= required_database_version ): await migration_instance.upgrade() await self.execute_async( sqlalchemy.update(metadata.Metadata) .where(metadata.Metadata.key == 'database_version') .values({'value': str(migration_instance.number)}) ) last_migration_number = migration_instance.number self.ap.logger.info(f'Migration {migration_instance.number} completed.') self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') await self.write_space_model_providers() async def create_tables(self): # create tables async with self.get_db_engine().connect() as conn: await conn.run_sync(self.meta.create_all) await conn.commit() # ======= write initial data ======= # write initial metadata self.ap.logger.info('Creating initial metadata...') for item in metadata.initial_metadata: # check if the item exists result = await self.execute_async( sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key']) ) row = result.first() if row is None: await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item)) async def write_space_model_providers(self): space_models_gateway_api_url = self.ap.instance_config.data.get('space', {}).get( 'models_gateway_api_url', 'https://api.langbot.cloud/v1' ) # write space model providers result = await self.execute_async( sqlalchemy.select(persistence_model.ModelProvider).where( persistence_model.ModelProvider.requester == 'space-chat-completions' ) ) exists_space_chat_completions_model_provider = result.first() # api keys will be set/updated when the oauth callback if exists_space_chat_completions_model_provider is None: self.ap.logger.info('Creating space model providers...') space_chat_completions_model_provider = { 'uuid': '00000000-0000-0000-0000-000000000000', 'name': 'LangBot Models', 'requester': 'space-chat-completions', 'base_url': space_models_gateway_api_url, 'api_keys': [], } await self.execute_async( sqlalchemy.insert(persistence_model.ModelProvider).values(space_chat_completions_model_provider) ) else: if exists_space_chat_completions_model_provider.base_url != space_models_gateway_api_url: await self.execute_async( sqlalchemy.update(persistence_model.ModelProvider) .where(persistence_model.ModelProvider.uuid == exists_space_chat_completions_model_provider.uuid) .values({'base_url': space_models_gateway_api_url}) ) # ================================= async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult: async with self.get_db_engine().connect() as conn: result = await conn.execute(*args, **kwargs) await conn.commit() return result def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: return self.db.get_engine() def serialize_model( self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = [] ) -> dict: return { column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat() for column in model.__table__.columns if column.name not in masked_columns }