diff --git a/model/ability.go b/model/ability.go index f522967..01fea9e 100644 --- a/model/ability.go +++ b/model/ability.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "gorm.io/gorm" "one-api/common" "strings" ) @@ -27,8 +28,7 @@ func GetGroupModels(group string) []string { return models } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - var abilities []Ability +func getPriority(group string, model string, retry int) (int, error) { groupCol := "`group`" trueVal := "1" if common.UsingPostgreSQL { @@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { trueVal = "true" } - var err error = nil + var priorities []int + err := DB.Model(&Ability{}). + Select("DISTINCT(priority)"). + Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model). + Order("priority DESC"). // 按优先级降序排序 + Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 + + if err != nil { + // 处理错误 + return 0, err + } + + // 确定要使用的优先级 + var priorityToUse int + if retry >= len(priorities) { + // 如果重试次数大于优先级数,则使用最小的优先级 + priorityToUse = priorities[len(priorities)-1] + } else { + priorityToUse = priorities[retry] + } + return priorityToUse, nil +} + +func getChannelQuery(group string, model string, retry int) *gorm.DB { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + if retry != 0 { + priority, err := getPriority(group, model, retry) + if err != nil { + common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) + } else { + channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority) + } + } + + return channelQuery +} + +func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { + var abilities []Ability + + var err error = nil + channelQuery := getChannelQuery(group, model, retry) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("weight DESC").Find(&abilities).Error } else { diff --git a/model/cache.go b/model/cache.go index 78bdc17..dc2ed3b 100644 --- a/model/cache.go +++ b/model/cache.go @@ -272,7 +272,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, retry) } channelSyncLock.RLock() defer channelSyncLock.RUnlock()