feat: channel support weight (#85)

*  feat: channel support weight

* 💄 improve: show version

* 💄 improve: Channel add copy operation

* 💄 improve: Channel support batch add
This commit is contained in:
Buer
2024-03-06 18:01:43 +08:00
committed by GitHub
parent 7c78ed9fad
commit dd3e79a20d
44 changed files with 1425 additions and 1019 deletions

View File

@@ -11,6 +11,7 @@ type Ability struct {
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight *uint `json:"weight" gorm:"default:1"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
@@ -67,6 +68,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: channel.Weight,
}
abilities = append(abilities, ability)
}
@@ -98,3 +100,49 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func GetEnabledAbility() ([]*Ability, error) {
trueVal := "1"
if common.UsingPostgreSQL {
trueVal = "true"
}
var abilities []*Ability
err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error
return abilities, err
}
type AbilityChannelGroup struct {
Group string `json:"group"`
Model string `json:"model"`
Priority int `json:"priority"`
ChannelIds string `json:"channel_ids"`
}
func GetAbilityChannelGroup() ([]*AbilityChannelGroup, error) {
var abilities []*AbilityChannelGroup
var channelSql string
if common.UsingPostgreSQL {
channelSql = `string_agg("channel_id"::text, ',')`
} else if common.UsingSQLite {
channelSql = `group_concat("channel_id", ',')`
} else {
channelSql = "GROUP_CONCAT(`channel_id` SEPARATOR ',')"
}
trueVal := "1"
if common.UsingPostgreSQL {
trueVal = "true"
}
err := DB.Raw(`
SELECT `+quotePostgresField("group")+`, model, priority, `+channelSql+` as channel_ids
FROM abilities
WHERE enabled = ?
GROUP BY `+quotePostgresField("group")+`, model, priority
ORDER BY priority DESC
`, trueVal).Scan(&abilities).Error
return abilities, err
}

176
model/balancer.go Normal file
View File

@@ -0,0 +1,176 @@
package model
import (
"errors"
"math/rand"
"one-api/common"
"strings"
"sync"
"time"
)
type ChannelChoice struct {
Channel *Channel
CooldownsTime int64
}
type ChannelsChooser struct {
sync.RWMutex
Channels map[int]*ChannelChoice
Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
}
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
if common.RetryCooldownSeconds == 0 {
return false
}
cc.Lock()
defer cc.Unlock()
if _, ok := cc.Channels[channelId]; !ok {
return false
}
cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds)
return true
}
func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
nowTime := time.Now().Unix()
totalWeight := 0
validChannels := make([]*ChannelChoice, 0, len(channelIds))
for _, channelId := range channelIds {
if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime {
weight := int(*choice.Channel.Weight)
totalWeight += weight
validChannels = append(validChannels, choice)
}
}
if len(validChannels) == 0 {
return nil
}
if len(validChannels) == 1 {
return validChannels[0].Channel
}
choiceWeight := rand.Intn(totalWeight)
for _, choice := range validChannels {
weight := int(*choice.Channel.Weight)
choiceWeight -= weight
if choiceWeight < 0 {
return choice.Channel
}
}
return nil
}
func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
cc.RLock()
defer cc.RUnlock()
if _, ok := cc.Rule[group]; !ok {
return nil, errors.New("group not found")
}
if _, ok := cc.Rule[group][model]; !ok {
return nil, errors.New("model not found")
}
channelsPriority := cc.Rule[group][model]
if len(channelsPriority) == 0 {
return nil, errors.New("channel not found")
}
for _, priority := range channelsPriority {
channel := cc.Balancer(priority)
if channel != nil {
return channel, nil
}
}
return nil, errors.New("channel not found")
}
func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
if !common.MemoryCacheEnabled {
return GetGroupModels(group)
}
cc.RLock()
defer cc.RUnlock()
if _, ok := cc.Rule[group]; !ok {
return nil, errors.New("group not found")
}
models := make([]string, 0, len(cc.Rule[group]))
for model := range cc.Rule[group] {
models = append(models, model)
}
return models, nil
}
var ChannelGroup = ChannelsChooser{}
func InitChannelGroup() {
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
abilities, err := GetAbilityChannelGroup()
if err != nil {
common.SysLog("get enabled abilities failed: " + err.Error())
return
}
newGroup := make(map[string]map[string][][]int)
newChannels := make(map[int]*ChannelChoice)
for _, channel := range channels {
if *channel.Weight == 0 {
channel.Weight = &common.DefaultChannelWeight
}
newChannels[channel.Id] = &ChannelChoice{
Channel: channel,
CooldownsTime: 0,
}
}
for _, ability := range abilities {
if _, ok := newGroup[ability.Group]; !ok {
newGroup[ability.Group] = make(map[string][][]int)
}
if _, ok := newGroup[ability.Group][ability.Model]; !ok {
newGroup[ability.Group][ability.Model] = make([][]int, 0)
}
var priorityIds []int
// 逗号分割 ability.ChannelId
channelIds := strings.Split(ability.ChannelIds, ",")
for _, channelId := range channelIds {
priorityIds = append(priorityIds, common.String2Int(channelId))
}
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
}
ChannelGroup.Lock()
ChannelGroup.Rule = newGroup
ChannelGroup.Channels = newChannels
ChannelGroup.Unlock()
common.SysLog("channels synced from database")
}
func SyncChannelGroup(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database")
InitChannelGroup()
}
}

View File

@@ -2,14 +2,9 @@ package model
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
"time"
)
@@ -131,104 +126,3 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
return userEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
var abilities []*Ability
DB.Find(&abilities)
groups := make(map[string]bool)
for _, ability := range abilities {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
}
for _, channel := range channels {
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database")
InitChannelCache()
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
}
func CacheGetGroupModels(group string) ([]string, error) {
if !common.MemoryCacheEnabled {
return GetGroupModels(group)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
groupModels := group2model2channels[group]
if groupModels == nil {
return nil, errors.New("group not found")
}
models := make([]string, 0)
for model := range groupModels {
models = append(models, model)
}
return models, nil
}

View File

@@ -13,7 +13,7 @@ type Channel struct {
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
Status int `json:"status" form:"status" gorm:"default:1"`
Name string `json:"name" form:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
Weight *uint `json:"weight" gorm:"default:1"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
@@ -95,11 +95,8 @@ func GetAllChannels() ([]*Channel, error) {
func GetChannelById(id int, selectAll bool) (*Channel, error) {
channel := Channel{Id: id}
var err error = nil
if selectAll {
err = DB.First(&channel, "id = ?", id).Error
} else {
err = DB.Omit("key").First(&channel, "id = ?", id).Error
}
err = DB.First(&channel, "id = ?", id).Error
return &channel, err
}

View File

@@ -77,6 +77,8 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
common.OptionMapRWMutex.Unlock()
initModelRatio()
loadOptionsFromDatabase()
@@ -146,6 +148,7 @@ var optionIntMap = map[string]*int{
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
"PreConsumedQuota": &common.PreConsumedQuota,
"RetryTimes": &common.RetryTimes,
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
}
var optionBoolMap = map[string]*bool{