mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 13:43:42 +08:00
feat: add weight_mapping for channel
This commit is contained in:
@@ -2,7 +2,6 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
@@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
}
|
||||
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
models_ := channel.GetModels()
|
||||
groups_ := channel.GetGroups()
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mroth/weightedrand/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int]
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
@@ -148,35 +147,51 @@ func InitChannelCache() {
|
||||
for _, ability := range abilities {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int])
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int])
|
||||
}
|
||||
for _, channel := range channels {
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
groups := channel.GetGroups()
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
models := channel.GetModels()
|
||||
weightMapping := channel.GetWeightMapping()
|
||||
for _, model := range models {
|
||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
||||
newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0)
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
||||
weight, ok := weightMapping[model]
|
||||
if weight < 0 || !ok {
|
||||
// use default value if:
|
||||
// weight < 0: invalid
|
||||
// !ok: weight not set
|
||||
weight = common.DefaultWeight
|
||||
}
|
||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int])
|
||||
for group, model2channels := range newGroup2model2channels {
|
||||
m[group] = make(map[string]*weightedrand.Chooser[*Channel, int])
|
||||
for model, channels := range model2channels {
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||
})
|
||||
newGroup2model2channels[group][model] = channels
|
||||
if len(channels) == 0 {
|
||||
common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model))
|
||||
continue
|
||||
}
|
||||
c, err := weightedrand.NewChooser(channels...)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
m[group][model] = c
|
||||
}
|
||||
}
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
group2model2channels = m
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
@@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
channels := group2model2channels[group][model]
|
||||
if len(channels) == 0 {
|
||||
if channels == nil {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
return channels.Pick(), nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -23,6 +25,7 @@ type Channel struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
}
|
||||
|
||||
@@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) {
|
||||
if channel.WeightMapping == nil || *channel.WeightMapping == "" {
|
||||
return
|
||||
}
|
||||
err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal weight mapping: " + err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (channel *Channel) FixWeightMapping() {
|
||||
var weightMapping map[string]int
|
||||
if channel.WeightMapping == nil || *channel.WeightMapping == "" {
|
||||
weightMapping = make(map[string]int)
|
||||
} else {
|
||||
err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal weight mapping: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
models := channel.GetModels()
|
||||
for _, model := range models {
|
||||
if _, ok := weightMapping[model]; !ok {
|
||||
weightMapping[model] = common.DefaultWeight
|
||||
}
|
||||
}
|
||||
|
||||
jsonStr, err := json.Marshal(weightMapping)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal weight mapping: " + err.Error())
|
||||
}
|
||||
var result = string(jsonStr)
|
||||
channel.WeightMapping = &result
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
return common.SplitDistinct(channel.Models, ",")
|
||||
}
|
||||
|
||||
func (channel *Channel) GetGroups() []string {
|
||||
return common.SplitDistinct(channel.Group, ",")
|
||||
}
|
||||
|
||||
func (channel *Channel) GetPriority() int64 {
|
||||
if channel.Priority == nil {
|
||||
return 0
|
||||
|
||||
@@ -71,6 +71,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight)
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
@@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
||||
case "RetryTimes":
|
||||
common.RetryTimes, _ = strconv.Atoi(value)
|
||||
case "DefaultWeight":
|
||||
common.DefaultWeight, _ = strconv.Atoi(value)
|
||||
case "ModelRatio":
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
|
||||
Reference in New Issue
Block a user