mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-18 06:03:42 +08:00
✨ feat: channel support weight (#85)
* ✨ feat: channel support weight * 💄 improve: show version * 💄 improve: Channel add copy operation * 💄 improve: Channel support batch add
This commit is contained in:
@@ -11,6 +11,7 @@ type Ability struct {
|
||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||
Weight *uint `json:"weight" gorm:"default:1"`
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
@@ -67,6 +68,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
Weight: channel.Weight,
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
@@ -98,3 +100,49 @@ func (channel *Channel) UpdateAbilities() error {
|
||||
func UpdateAbilityStatus(channelId int, status bool) error {
|
||||
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
||||
}
|
||||
|
||||
func GetEnabledAbility() ([]*Ability, error) {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var abilities []*Ability
|
||||
err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
type AbilityChannelGroup struct {
|
||||
Group string `json:"group"`
|
||||
Model string `json:"model"`
|
||||
Priority int `json:"priority"`
|
||||
ChannelIds string `json:"channel_ids"`
|
||||
}
|
||||
|
||||
func GetAbilityChannelGroup() ([]*AbilityChannelGroup, error) {
|
||||
var abilities []*AbilityChannelGroup
|
||||
|
||||
var channelSql string
|
||||
if common.UsingPostgreSQL {
|
||||
channelSql = `string_agg("channel_id"::text, ',')`
|
||||
} else if common.UsingSQLite {
|
||||
channelSql = `group_concat("channel_id", ',')`
|
||||
} else {
|
||||
channelSql = "GROUP_CONCAT(`channel_id` SEPARATOR ',')"
|
||||
}
|
||||
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
err := DB.Raw(`
|
||||
SELECT `+quotePostgresField("group")+`, model, priority, `+channelSql+` as channel_ids
|
||||
FROM abilities
|
||||
WHERE enabled = ?
|
||||
GROUP BY `+quotePostgresField("group")+`, model, priority
|
||||
ORDER BY priority DESC
|
||||
`, trueVal).Scan(&abilities).Error
|
||||
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
176
model/balancer.go
Normal file
176
model/balancer.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChannelChoice struct {
|
||||
Channel *Channel
|
||||
CooldownsTime int64
|
||||
}
|
||||
|
||||
type ChannelsChooser struct {
|
||||
sync.RWMutex
|
||||
Channels map[int]*ChannelChoice
|
||||
Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
|
||||
if common.RetryCooldownSeconds == 0 {
|
||||
return false
|
||||
}
|
||||
cc.Lock()
|
||||
defer cc.Unlock()
|
||||
if _, ok := cc.Channels[channelId]; !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds)
|
||||
return true
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
|
||||
nowTime := time.Now().Unix()
|
||||
totalWeight := 0
|
||||
|
||||
validChannels := make([]*ChannelChoice, 0, len(channelIds))
|
||||
for _, channelId := range channelIds {
|
||||
if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime {
|
||||
weight := int(*choice.Channel.Weight)
|
||||
totalWeight += weight
|
||||
validChannels = append(validChannels, choice)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validChannels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(validChannels) == 1 {
|
||||
return validChannels[0].Channel
|
||||
}
|
||||
|
||||
choiceWeight := rand.Intn(totalWeight)
|
||||
for _, choice := range validChannels {
|
||||
weight := int(*choice.Channel.Weight)
|
||||
choiceWeight -= weight
|
||||
if choiceWeight < 0 {
|
||||
return choice.Channel
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
}
|
||||
cc.RLock()
|
||||
defer cc.RUnlock()
|
||||
if _, ok := cc.Rule[group]; !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
if _, ok := cc.Rule[group][model]; !ok {
|
||||
return nil, errors.New("model not found")
|
||||
}
|
||||
|
||||
channelsPriority := cc.Rule[group][model]
|
||||
if len(channelsPriority) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
for _, priority := range channelsPriority {
|
||||
channel := cc.Balancer(priority)
|
||||
if channel != nil {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetGroupModels(group)
|
||||
}
|
||||
|
||||
cc.RLock()
|
||||
defer cc.RUnlock()
|
||||
|
||||
if _, ok := cc.Rule[group]; !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(cc.Rule[group]))
|
||||
for model := range cc.Rule[group] {
|
||||
models = append(models, model)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
var ChannelGroup = ChannelsChooser{}
|
||||
|
||||
func InitChannelGroup() {
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
|
||||
abilities, err := GetAbilityChannelGroup()
|
||||
if err != nil {
|
||||
common.SysLog("get enabled abilities failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
newGroup := make(map[string]map[string][][]int)
|
||||
newChannels := make(map[int]*ChannelChoice)
|
||||
|
||||
for _, channel := range channels {
|
||||
if *channel.Weight == 0 {
|
||||
channel.Weight = &common.DefaultChannelWeight
|
||||
}
|
||||
newChannels[channel.Id] = &ChannelChoice{
|
||||
Channel: channel,
|
||||
CooldownsTime: 0,
|
||||
}
|
||||
}
|
||||
|
||||
for _, ability := range abilities {
|
||||
if _, ok := newGroup[ability.Group]; !ok {
|
||||
newGroup[ability.Group] = make(map[string][][]int)
|
||||
}
|
||||
|
||||
if _, ok := newGroup[ability.Group][ability.Model]; !ok {
|
||||
newGroup[ability.Group][ability.Model] = make([][]int, 0)
|
||||
}
|
||||
|
||||
var priorityIds []int
|
||||
// 逗号分割 ability.ChannelId
|
||||
channelIds := strings.Split(ability.ChannelIds, ",")
|
||||
for _, channelId := range channelIds {
|
||||
priorityIds = append(priorityIds, common.String2Int(channelId))
|
||||
}
|
||||
|
||||
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
|
||||
}
|
||||
|
||||
ChannelGroup.Lock()
|
||||
ChannelGroup.Rule = newGroup
|
||||
ChannelGroup.Channels = newChannels
|
||||
ChannelGroup.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelGroup(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
InitChannelGroup()
|
||||
}
|
||||
}
|
||||
106
model/cache.go
106
model/cache.go
@@ -2,14 +2,9 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -131,104 +126,3 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
newChannelId2channel := make(map[int]*Channel)
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
for _, channel := range channels {
|
||||
newChannelId2channel[channel.Id] = channel
|
||||
}
|
||||
var abilities []*Ability
|
||||
DB.Find(&abilities)
|
||||
groups := make(map[string]bool)
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
}
|
||||
|
||||
func CacheGetGroupModels(group string) ([]string, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetGroupModels(group)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
groupModels := group2model2channels[group]
|
||||
if groupModels == nil {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
models := make([]string, 0)
|
||||
for model := range groupModels {
|
||||
models = append(models, model)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ type Channel struct {
|
||||
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
|
||||
Status int `json:"status" form:"status" gorm:"default:1"`
|
||||
Name string `json:"name" form:"name" gorm:"index"`
|
||||
Weight *uint `json:"weight" gorm:"default:0"`
|
||||
Weight *uint `json:"weight" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||
ResponseTime int `json:"response_time"` // in milliseconds
|
||||
@@ -95,11 +95,8 @@ func GetAllChannels() ([]*Channel, error) {
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
channel := Channel{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", id).Error
|
||||
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
|
||||
@@ -77,6 +77,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
initModelRatio()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -146,6 +148,7 @@ var optionIntMap = map[string]*int{
|
||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||
"RetryTimes": &common.RetryTimes,
|
||||
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||
}
|
||||
|
||||
var optionBoolMap = map[string]*bool{
|
||||
|
||||
Reference in New Issue
Block a user