From 6b71db7ce2584d1066477dba002be2d6c498322d Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 20 Apr 2024 21:05:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=8A=B6=E6=80=81=E7=A0=81=E5=A4=8D?= =?UTF-8?q?=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/distributor.go | 1 + model/channel.go | 8 +++++++ relay/relay-text.go | 8 ++++++- service/error.go | 19 +++++++++++++++ web/src/pages/Channel/EditChannel.js | 35 ++++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 108c783..ae5707f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -177,6 +177,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } c.Set("auto_ban", ban) c.Set("model_mapping", channel.GetModelMapping()) + c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 diff --git a/model/channel.go b/model/channel.go index 3e30ad4..c0c21c0 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,6 +25,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(64);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` } @@ -153,6 +154,13 @@ func (channel *Channel) GetModelMapping() string { return *channel.ModelMapping } +func (channel *Channel) GetStatusCodeMapping() string { + if channel.StatusCodeMapping == nil { + return "" + } + return *channel.StatusCodeMapping +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/relay/relay-text.go b/relay/relay-text.go index 71a47c2..6026560 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -154,6 +154,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { requestBody = bytes.NewBuffer(jsonData) } + statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -162,12 +163,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return service.RelayErrorHandler(resp) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) diff --git a/service/error.go b/service/error.go index 39eb0f9..4b00f37 100644 --- a/service/error.go +++ b/service/error.go @@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW } return } + +func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) { + if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" { + return + } + statusCodeMapping := make(map[string]string) + err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) + if err != nil { + return + } + if openaiErr.StatusCode == http.StatusOK { + return + } + codeStr := strconv.Itoa(openaiErr.StatusCode) + if _, ok := statusCodeMapping[codeStr]; ok { + intCode, _ := strconv.Atoi(statusCodeMapping[codeStr]) + openaiErr.StatusCode = intCode + } +} diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0fe6e2b..de00d89 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -29,6 +29,10 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k', }; +const STATUS_CODE_MAPPING_EXAMPLE = { + 400: '500', +}; + function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { @@ -61,6 +65,7 @@ const EditChannel = (props) => { base_url: '', other: '', model_mapping: '', + status_code_mapping: '', models: [], auto_ban: 1, test_model: '', @@ -629,6 +634,36 @@ const EditChannel = (props) => { > 填入模板 +