mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-07 22:53:41 +08:00
feat: 加入渠道加权随机功能
This commit is contained in:
@@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
}
|
||||
@@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range channels[:endIdx] {
|
||||
totalWeight += channel.GetWeight()
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
// If all weights are 0, select a channel randomly
|
||||
return channels[rand.Intn(endIdx)], nil
|
||||
}
|
||||
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range channels[:endIdx] {
|
||||
randomWeight -= channel.GetWeight()
|
||||
if randomWeight <= 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
// return the last channel if no channel is found
|
||||
return channels[endIdx-1], nil
|
||||
}
|
||||
|
||||
func CacheGetChannel(id int) (*Channel, error) {
|
||||
|
||||
Reference in New Issue
Block a user