mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-08 01:33:43 +08:00
Merge remote-tracking branch 'origin/upstream/main'
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"gorm.io/driver/mysql"
|
||||
@@ -18,8 +19,9 @@ import (
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
func CreateRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != util.UserStatusEnabled {
|
||||
if err := DB.First(&user).Error; err != nil {
|
||||
@@ -42,9 +44,9 @@ func createRootAccountIfNeed() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func chooseDB() (*gorm.DB, error) {
|
||||
if os.Getenv("SQL_DSN") != "" {
|
||||
dsn := os.Getenv("SQL_DSN")
|
||||
func chooseDB(envName string) (*gorm.DB, error) {
|
||||
if os.Getenv(envName) != "" {
|
||||
dsn := os.Getenv(envName)
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
// Use PostgreSQL
|
||||
logger.SysLog("using PostgreSQL as database")
|
||||
@@ -72,23 +74,22 @@ func chooseDB() (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
|
||||
func InitDB() (err error) {
|
||||
db, err := chooseDB()
|
||||
func InitDB(envName string) (db *gorm.DB, err error) {
|
||||
db, err = chooseDB(envName)
|
||||
if err == nil {
|
||||
if config.DebugSQLEnabled {
|
||||
db = db.Debug()
|
||||
}
|
||||
DB = db
|
||||
sqlDB, err := DB.DB()
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
|
||||
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||
|
||||
if !config.IsMasterNode {
|
||||
return nil
|
||||
return db, err
|
||||
}
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
@@ -96,46 +97,55 @@ func InitDB() (err error) {
|
||||
logger.SysLog("database migration started")
|
||||
err = db.AutoMigrate(&Channel{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&Token{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&User{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&Option{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&Redemption{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&Ability{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
err = db.AutoMigrate(&Log{})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
return nil, err
|
||||
}
|
||||
logger.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return errors.WithStack(err)
|
||||
return db, err
|
||||
} else {
|
||||
logger.FatalLog(err)
|
||||
}
|
||||
return errors.WithStack(err)
|
||||
return db, err
|
||||
}
|
||||
|
||||
func CloseDB() error {
|
||||
sqlDB, err := DB.DB()
|
||||
func closeDB(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
err = sqlDB.Close()
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func CloseDB() error {
|
||||
if LOG_DB != DB {
|
||||
err := closeDB(LOG_DB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return closeDB(DB)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user