mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-23 02:26:36 +08:00
feat: 加入渠道加权随机功能
This commit is contained in:
parent
1a8a24698f
commit
bdd611fd33
@ -168,6 +168,11 @@ func GetRandomString(length int) string {
|
|||||||
return string(key)
|
return string(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRandomInt(max int) int {
|
||||||
|
//rand.Seed(time.Now().UnixNano())
|
||||||
|
return rand.Intn(max)
|
||||||
|
}
|
||||||
|
|
||||||
func GetTimestamp() int64 {
|
func GetTimestamp() int64 {
|
||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ type Ability struct {
|
|||||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
|
Weight uint `json:"weight" gorm:"default:0;index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGroupModels(group string) []string {
|
func GetGroupModels(group string) []string {
|
||||||
@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
ability := Ability{}
|
var abilities []Ability
|
||||||
groupCol := "`group`"
|
groupCol := "`group`"
|
||||||
trueVal := "1"
|
trueVal := "1"
|
||||||
if common.UsingPostgreSQL {
|
if common.UsingPostgreSQL {
|
||||||
@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||||
} else {
|
} else {
|
||||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
channel := Channel{}
|
channel := Channel{}
|
||||||
channel.Id = ability.ChannelId
|
if len(abilities) > 0 {
|
||||||
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
|
// Randomly choose one
|
||||||
|
weightSum := uint(0)
|
||||||
|
for _, ability_ := range abilities {
|
||||||
|
weightSum += ability_.Weight
|
||||||
|
}
|
||||||
|
if weightSum == 0 {
|
||||||
|
// All weight is 0, randomly choose one
|
||||||
|
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
||||||
|
} else {
|
||||||
|
// Randomly choose one
|
||||||
|
weight := common.GetRandomInt(int(weightSum))
|
||||||
|
for _, ability_ := range abilities {
|
||||||
|
weight -= int(ability_.Weight)
|
||||||
|
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||||
|
if weight <= 0 {
|
||||||
|
channel.Id = ability_.ChannelId
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||||
Priority: channel.Priority,
|
Priority: channel.Priority,
|
||||||
|
Weight: uint(channel.GetWeight()),
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
}
|
}
|
||||||
|
@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
model = "gpt-4-gizmo-*"
|
model = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if memory cache is disabled, get channel directly from database
|
||||||
if !common.MemoryCacheEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
return GetRandomSatisfiedChannel(group, model)
|
return GetRandomSatisfiedChannel(group, model)
|
||||||
}
|
}
|
||||||
@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
idx := rand.Intn(endIdx)
|
// Calculate the total weight of all channels up to endIdx
|
||||||
return channels[idx], nil
|
totalWeight := 0
|
||||||
|
for _, channel := range channels[:endIdx] {
|
||||||
|
totalWeight += channel.GetWeight()
|
||||||
|
}
|
||||||
|
|
||||||
|
if totalWeight == 0 {
|
||||||
|
// If all weights are 0, select a channel randomly
|
||||||
|
return channels[rand.Intn(endIdx)], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random value in the range [0, totalWeight)
|
||||||
|
randomWeight := rand.Intn(totalWeight)
|
||||||
|
|
||||||
|
// Find a channel based on its weight
|
||||||
|
for _, channel := range channels[:endIdx] {
|
||||||
|
randomWeight -= channel.GetWeight()
|
||||||
|
if randomWeight <= 0 {
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// return the last channel if no channel is found
|
||||||
|
return channels[endIdx-1], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetChannel(id int) (*Channel, error) {
|
func CacheGetChannel(id int) (*Channel, error) {
|
||||||
|
@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
|
|||||||
return *channel.Priority
|
return *channel.Priority
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetWeight() int {
|
||||||
|
if channel.Weight == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int(*channel.Weight)
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetBaseURL() string {
|
func (channel *Channel) GetBaseURL() string {
|
||||||
if channel.BaseURL == nil {
|
if channel.BaseURL == nil {
|
||||||
return ""
|
return ""
|
||||||
|
@ -163,7 +163,7 @@ const ChannelsTable = () => {
|
|||||||
<div>
|
<div>
|
||||||
<InputNumber
|
<InputNumber
|
||||||
style={{width: 70}}
|
style={{width: 70}}
|
||||||
name='name'
|
name='priority'
|
||||||
onChange={value => {
|
onChange={value => {
|
||||||
manageChannel(record.id, 'priority', record, value);
|
manageChannel(record.id, 'priority', record, value);
|
||||||
}}
|
}}
|
||||||
@ -174,6 +174,25 @@ const ChannelsTable = () => {
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
title: '权重',
|
||||||
|
dataIndex: 'weight',
|
||||||
|
render: (text, record, index) => {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<InputNumber
|
||||||
|
style={{width: 70}}
|
||||||
|
name='weight'
|
||||||
|
onChange={value => {
|
||||||
|
manageChannel(record.id, 'weight', record, value);
|
||||||
|
}}
|
||||||
|
defaultValue={record.weight}
|
||||||
|
min={0}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
title: '',
|
title: '',
|
||||||
dataIndex: 'operate',
|
dataIndex: 'operate',
|
||||||
|
Loading…
Reference in New Issue
Block a user