feat: add weight_mapping for channel

This commit is contained in:
WqyJh
2023-11-01 14:58:04 +08:00
parent aec343dc38
commit 4a27d75d57
11 changed files with 153 additions and 37 deletions

View File

@@ -2,7 +2,6 @@ package model
import (
"one-api/common"
"strings"
)
type Ability struct {
@@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
models_ := channel.GetModels()
groups_ := channel.GetGroups()
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {

View File

@@ -4,13 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/mroth/weightedrand/v2"
)
var (
@@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int]
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -148,35 +147,51 @@ func InitChannelCache() {
for _, ability := range abilities {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int])
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int])
}
for _, channel := range channels {
groups := strings.Split(channel.Group, ",")
groups := channel.GetGroups()
for _, group := range groups {
models := strings.Split(channel.Models, ",")
models := channel.GetModels()
weightMapping := channel.GetWeightMapping()
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0)
newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
weight, ok := weightMapping[model]
if weight < 0 || !ok {
// use default value if:
// weight < 0: invalid
// !ok: weight not set
weight = common.DefaultWeight
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight))
}
}
}
// sort by priority
m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int])
for group, model2channels := range newGroup2model2channels {
m[group] = make(map[string]*weightedrand.Chooser[*Channel, int])
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
if len(channels) == 0 {
common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model))
continue
}
c, err := weightedrand.NewChooser(channels...)
if err != nil {
common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error()))
continue
}
m[group][model] = c
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
group2model2channels = m
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
if len(channels) == 0 {
if channels == nil {
return nil, errors.New("channel not found")
}
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
return channels.Pick(), nil
}

View File

@@ -1,8 +1,10 @@
package model
import (
"gorm.io/gorm"
"encoding/json"
"one-api/common"
"gorm.io/gorm"
)
type Channel struct {
@@ -23,6 +25,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
@@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) {
if channel.WeightMapping == nil || *channel.WeightMapping == "" {
return
}
err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
if err != nil {
common.SysError("failed to unmarshal weight mapping: " + err.Error())
}
return
}
func (channel *Channel) FixWeightMapping() {
var weightMapping map[string]int
if channel.WeightMapping == nil || *channel.WeightMapping == "" {
weightMapping = make(map[string]int)
} else {
err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
if err != nil {
common.SysError("failed to marshal weight mapping: " + err.Error())
}
}
models := channel.GetModels()
for _, model := range models {
if _, ok := weightMapping[model]; !ok {
weightMapping[model] = common.DefaultWeight
}
}
jsonStr, err := json.Marshal(weightMapping)
if err != nil {
common.SysError("failed to marshal weight mapping: " + err.Error())
}
var result = string(jsonStr)
channel.WeightMapping = &result
}
func (channel *Channel) GetModels() []string {
return common.SplitDistinct(channel.Models, ",")
}
func (channel *Channel) GetGroups() []string {
return common.SplitDistinct(channel.Group, ",")
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0

View File

@@ -71,6 +71,7 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DefaultWeight":
common.DefaultWeight, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":