From dbe3930a8cf73539c3632256a70c1c5bd7981bab Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 1 Mar 2024 13:49:23 +0000 Subject: [PATCH] fix: Switch to channel-ratio - Use channel ratios instead of group ratios in all applicable places - Start using the lowest channel ratio of the specified channel's groups --- middleware/distributor.go | 10 ++++++++++ relay/controller/audio.go | 5 +++-- relay/controller/image.go | 12 +++++++----- relay/controller/text.go | 3 ++- relay/util/relay_meta.go | 2 ++ 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index aeb2796a..0fda0670 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -85,6 +85,16 @@ func Distribute() func(c *gin.Context) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + // set minimal group ratio as channel_ratio + var minimalRatio float64 + for _, grp := range strings.Split(channel.Group, ",") { + v := common.GetGroupRatio(grp) + if minimalRatio == 0 || v < minimalRatio { + minimalRatio = v + } + } + c.Set("channel_ratio", minimalRatio) + c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) diff --git a/relay/controller/audio.go b/relay/controller/audio.go index ee8771c9..bd39944f 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -28,7 +28,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - group := c.GetString("group") + // group := c.GetString("group") tokenName := c.GetString("token_name") var ttsRequest openai.TextToSpeechRequest @@ -47,7 +47,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } modelRatio := common.GetModelRatio(audioModel) - groupRatio := common.GetGroupRatio(group) + // groupRatio := common.GetGroupRatio(group) + groupRatio := c.GetFloat64("channel_ratio") ratio := modelRatio * groupRatio var quota int var preConsumedQuota int diff --git a/relay/controller/image.go b/relay/controller/image.go index 6ec368f5..4e0ed172 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,15 +6,16 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" - "strings" "github.com/gin-gonic/gin" ) @@ -37,7 +38,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - group := c.GetString("group") + // group := c.GetString("group") var imageRequest openai.ImageRequest err := common.UnmarshalBodyReusable(c, &imageRequest) @@ -131,7 +132,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio(group) + // groupRatio := common.GetGroupRatio(group) + groupRatio := c.GetFloat64("channel_ratio") ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) diff --git a/relay/controller/text.go b/relay/controller/text.go index e19b571f..46c62a3c 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -35,7 +35,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta.ActualModelName = textRequest.Model // get model ratio & group ratio modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(meta.Group) + // groupRatio := common.GetGroupRatio(meta.Group) + groupRatio := meta.ChannelRatio ratio := modelRatio * groupRatio // pre-consume quota promptTokens := getPromptTokens(textRequest, meta.Mode) diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go index 31b9d2b4..17135816 100644 --- a/relay/util/relay_meta.go +++ b/relay/util/relay_meta.go @@ -26,6 +26,7 @@ type RelayMeta struct { ActualModelName string RequestURLPath string PromptTokens int // only for DoResponse + ChannelRatio float64 } func GetRelayMeta(c *gin.Context) *RelayMeta { @@ -43,6 +44,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Config: nil, RequestURLPath: c.Request.URL.String(), + ChannelRatio: c.GetFloat64("channel_ratio"), } if meta.ChannelType == common.ChannelTypeAzure { meta.APIVersion = GetAzureAPIVersion(c)