From 7297c6269f1a7eff1774b69e2c8fea7b26ae5c21 Mon Sep 17 00:00:00 2001 From: papersnake Date: Sat, 2 Mar 2024 12:26:39 +0000 Subject: [PATCH] fix: add missing upstreamModelName --- middleware/distributor.go | 2 ++ relay/common/relay_info.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/middleware/distributor.go b/middleware/distributor.go index 1ca43dd..4930961 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -116,6 +116,8 @@ func Distribute() func(c *gin.Context) { abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } + + c.Set("model_name", modelRequest.Model) } c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8051cec..270a2d8 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -34,6 +34,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { userId := c.GetInt("id") group := c.GetString("group") tokenUnlimited := c.GetBool("token_unlimited_quota") + upstreamModelName := c.GetString("model_name") startTime := time.Now() apiType := constant.ChannelType2APIType(channelType) @@ -52,6 +53,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiType: apiType, ApiVersion: c.GetString("api_version"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + UpstreamModelName: upstreamModelName, } if info.BaseUrl == "" { info.BaseUrl = common.ChannelBaseURLs[channelType]