Merge remote-tracking branch 'upstream/main'

This commit is contained in:
ckt1031
2023-10-25 19:56:30 +08:00
parent d79a7b5902
commit e1d840e7dd
16 changed files with 80 additions and 47 deletions

View File

@@ -17,35 +17,25 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
var cmdWhere *Ability
cmdWhere := groupCol + " = ? and model = ? and enabled = " + trueVal
if stream {
cmdWhere = &Ability{
Group: group,
Model: model,
Enabled: true,
AllowStreaming: common.ChannelAllowStreamEnabled,
}
cmdWhere += " and allow_streaming = 1"
} else {
cmdWhere = &Ability{
Group: group,
Model: model,
Enabled: true,
AllowNonStreaming: common.ChannelAllowNonStreamEnabled,
}
cmdWhere += " and allow_non_streaming = 1"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere)
cmd1 := "`group` = ? and model = ? and enabled = 1 and priority = (?)"
if common.UsingPostgreSQL {
cmd1 = "\"group\" = ? and model = ? and enabled = 1 and priority = (?)"
}
channelQuery := DB.Where(cmd1, group, model, maxPrioritySubQuery)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {

View File

@@ -21,15 +21,19 @@ var (
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where(&Token{Key: key}).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where(&Token{Key: key}).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}

View File

@@ -59,19 +59,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingPostgreSQL {
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingSQLite {
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error

View File

@@ -51,8 +51,13 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -293,7 +293,12 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where(&User{Id: id}).Select("group").Find(&group).Error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
}