From 5acf0745412caf80b05d97477e080df62cbf0a89 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 17:32:28 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E7=A6=81=E7=94=A8=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 2 +- controller/relay.go | 19 +++++++++++++------ middleware/distributor.go | 7 +------ model/channel.go | 7 +++++++ relay/relay-mj.go | 2 +- 5 files changed, 23 insertions(+), 14 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index fe27978..95c4a60 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -240,7 +240,7 @@ func testAllChannels(notify bool) error { } // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { + if !channel.GetAutoBan() { ban = false } diff --git a/controller/relay.go b/controller/relay.go index 30217f0..3a6fb6f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -59,7 +59,7 @@ func Relay(c *gin.Context) { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) if !shouldRetry(c, openaiErr, common.RetryTimes-i) { break @@ -97,10 +97,16 @@ func addUsedChannel(c *gin.Context, channelId int) { func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { if retryCount == 0 { + autoBan := c.GetBool("auto_ban") + autoBanInt := 1 + if !autoBan { + autoBanInt = 0 + } return &model.Channel{ - Id: c.GetInt("channel_id"), - Type: c.GetInt("channel_type"), - Name: c.GetString("channel_name"), + Id: c.GetInt("channel_id"), + Type: c.GetInt("channel_type"), + Name: c.GetString("channel_name"), + AutoBan: &autoBanInt, }, nil } channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) @@ -154,8 +160,9 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry return true } -func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) { - autoBan := c.GetBool("auto_ban") +func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) if service.ShouldDisableChannel(channelType, err) && autoBan { service.DisableChannel(channelId, channelName, err.Error.Message) diff --git a/middleware/distributor.go b/middleware/distributor.go index fad868d..1be3b31 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -187,15 +187,10 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } - c.Set("auto_ban", ban) + c.Set("auto_ban", channel.GetAutoBan()) c.Set("model_mapping", channel.GetModelMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/model/channel.go b/model/channel.go index 7db3f07..3f9d9ed 100644 --- a/model/channel.go +++ b/model/channel.go @@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { channel.OtherInfo = string(otherInfoBytes) } +func (channel *Channel) GetAutoBan() bool { + if channel.AutoBan == nil { + return false + } + return *channel.AutoBan == 1 +} + func (channel *Channel) Save() error { return DB.Save(channel).Error } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 73ea468..4dd81c5 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -549,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if err != nil { common.SysError("get_channel_null: " + err.Error()) } - if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled { + if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") } }