mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
Merge branch 'support-postgres-sql' into refactor-main
This commit is contained in:
@@ -21,12 +21,23 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe
|
||||
|
||||
cmd := "`group` = ? and model = ? and enabled = 1"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
cmd = "\"group\" = ? and model = ? and enabled = true"
|
||||
}
|
||||
|
||||
if stream {
|
||||
cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled)
|
||||
} else {
|
||||
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
|
||||
}
|
||||
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
|
||||
}
|
||||
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
|
||||
@@ -21,13 +21,19 @@ var (
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
var token Token
|
||||
whereItem := "`key` = ?"
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
whereItem = "\"key\" = ?"
|
||||
}
|
||||
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where("key = ?", key).First(&token).Error
|
||||
err := DB.Where(whereItem, key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where("key = ?", key).First(&token).Error
|
||||
err := DB.Where(whereItem, key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -40,7 +40,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or key = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
||||
whereItem := "id = ? or name LIKE ? or `key` = ?"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
whereItem = "id = ? or name LIKE ? or \"key\" = ?"
|
||||
}
|
||||
|
||||
err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
@@ -58,7 +64,9 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
if common.UsingPostgreSQL {
|
||||
err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else if common.UsingSQLite {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
|
||||
|
||||
@@ -45,6 +45,7 @@ func InitDB() (err error) {
|
||||
if os.Getenv("POSTGRES_DSN") != "" {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
|
||||
@@ -52,7 +52,14 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
redemption := &Redemption{}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("key = ?", key).First(redemption).Error
|
||||
whereItem := "`key` = ?"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
whereItem = "\"key\" = ?"
|
||||
}
|
||||
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
|
||||
@@ -294,7 +294,13 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
|
||||
selectItem := "`group`"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
selectItem = "\"group\""
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user