feat: 加入渠道加权随机功能

This commit is contained in:
CaIon 2023-12-27 19:00:47 +08:00
parent 1a8a24698f
commit bdd611fd33
5 changed files with 86 additions and 8 deletions

View File

@ -168,6 +168,11 @@ func GetRandomString(length int) string {
return string(key)
}
func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano())
return rand.Intn(max)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}

View File

@ -11,6 +11,7 @@ type Ability struct {
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
}
func GetGroupModels(group string) []string {
@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var abilities []Ability
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
if len(abilities) > 0 {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight
}
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
}
} else {
return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
}
@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
}
abilities = append(abilities, ability)
}

View File

@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
model = "gpt-4-gizmo-*"
}
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
}
if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
}
func CacheGetChannel(id int) (*Channel, error) {

View File

@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
return *channel.Priority
}
func (channel *Channel) GetWeight() int {
if channel.Weight == nil {
return 0
}
return int(*channel.Weight)
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""

View File

@ -163,7 +163,7 @@ const ChannelsTable = () => {
<div>
<InputNumber
style={{width: 70}}
name='name'
name='priority'
onChange={value => {
manageChannel(record.id, 'priority', record, value);
}}
@ -174,6 +174,25 @@ const ChannelsTable = () => {
);
},
},
{
title: '权重',
dataIndex: 'weight',
render: (text, record, index) => {
return (
<div>
<InputNumber
style={{width: 70}}
name='weight'
onChange={value => {
manageChannel(record.id, 'weight', record, value);
}}
defaultValue={record.weight}
min={0}
/>
</div>
);
},
},
{
title: '',
dataIndex: 'operate',