mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 05:33:42 +08:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -17,35 +17,25 @@ type Ability struct {
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
|
||||
var cmdWhere *Ability
|
||||
cmdWhere := groupCol + " = ? and model = ? and enabled = " + trueVal
|
||||
|
||||
if stream {
|
||||
cmdWhere = &Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
AllowStreaming: common.ChannelAllowStreamEnabled,
|
||||
}
|
||||
cmdWhere += " and allow_streaming = 1"
|
||||
} else {
|
||||
cmdWhere = &Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
AllowNonStreaming: common.ChannelAllowNonStreamEnabled,
|
||||
}
|
||||
cmdWhere += " and allow_non_streaming = 1"
|
||||
}
|
||||
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere)
|
||||
|
||||
cmd1 := "`group` = ? and model = ? and enabled = 1 and priority = (?)"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
cmd1 = "\"group\" = ? and model = ? and enabled = 1 and priority = (?)"
|
||||
}
|
||||
|
||||
channelQuery := DB.Where(cmd1, group, model, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
} else {
|
||||
|
||||
@@ -21,15 +21,19 @@ var (
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
var token Token
|
||||
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where(&Token{Key: key}).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where(&Token{Key: key}).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -59,19 +59,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingPostgreSQL {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else if common.UsingSQLite {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RAND()").Limit(1).First(&channel).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
var err error
|
||||
err = DB.Create(&channels).Error
|
||||
|
||||
@@ -51,8 +51,13 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
}
|
||||
redemption := &Redemption{}
|
||||
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
|
||||
@@ -293,7 +293,12 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
err = DB.Model(&User{}).Where(&User{Id: id}).Select("group").Find(&group).Error
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user