Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-03-15 09:49:49 +00:00
49 changed files with 623 additions and 205 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql"
@@ -18,8 +19,9 @@ import (
)
var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error {
func CreateRootAccountIfNeed() error {
var user User
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
@@ -42,9 +44,9 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
logger.SysLog("using PostgreSQL as database")
@@ -72,23 +74,22 @@ func chooseDB() (*gorm.DB, error) {
})
}
func InitDB() (err error) {
db, err := chooseDB()
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if err == nil {
if config.DebugSQLEnabled {
db = db.Debug()
}
DB = db
sqlDB, err := DB.DB()
sqlDB, err := db.DB()
if err != nil {
return errors.WithStack(err)
return nil, err
}
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode {
return nil
return db, err
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
@@ -96,46 +97,55 @@ func InitDB() (err error) {
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return errors.WithStack(err)
return db, err
} else {
logger.FatalLog(err)
}
return errors.WithStack(err)
return db, err
}
func CloseDB() error {
sqlDB, err := DB.DB()
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return errors.WithStack(err)
}
err = sqlDB.Close()
return errors.WithStack(err)
}
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}