feat: 状态码复写

This commit is contained in:
CaIon 2024-04-20 21:05:23 +08:00
parent b8fb351fd8
commit 6b71db7ce2
5 changed files with 70 additions and 1 deletions

View File

@ -177,6 +177,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
} }
c.Set("auto_ban", ban) c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一 // TODO: api_version统一

View File

@ -25,6 +25,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(64);default:'default'"` Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"` AutoBan *int `json:"auto_ban" gorm:"default:1"`
} }
@ -153,6 +154,13 @@ func (channel *Channel) GetModelMapping() string {
return *channel.ModelMapping return *channel.ModelMapping
} }
func (channel *Channel) GetStatusCodeMapping() string {
if channel.StatusCodeMapping == nil {
return ""
}
return *channel.StatusCodeMapping
}
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
var err error var err error
err = DB.Create(channel).Error err = DB.Create(channel).Error

View File

@ -154,6 +154,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
} }
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@ -162,12 +163,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)

View File

@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
} }
return return
} }
func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
return
}
statusCodeMapping := make(map[string]string)
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
if err != nil {
return
}
if openaiErr.StatusCode == http.StatusOK {
return
}
codeStr := strconv.Itoa(openaiErr.StatusCode)
if _, ok := statusCodeMapping[codeStr]; ok {
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
openaiErr.StatusCode = intCode
}
}

View File

@ -29,6 +29,10 @@ const MODEL_MAPPING_EXAMPLE = {
'gpt-4-32k-0314': 'gpt-4-32k', 'gpt-4-32k-0314': 'gpt-4-32k',
}; };
const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500',
};
function type2secretPrompt(type) { function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') // inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) { switch (type) {
@ -61,6 +65,7 @@ const EditChannel = (props) => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
status_code_mapping: '',
models: [], models: [],
auto_ban: 1, auto_ban: 1,
test_model: '', test_model: '',
@ -629,6 +634,36 @@ const EditChannel = (props) => {
> >
填入模板 填入模板
</Typography.Text> </Typography.Text>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>
状态码复写仅影响本地判断不修改返回到上游的状态码
</Typography.Text>
</div>
<TextArea
placeholder={`此项可选用于复写返回的状态码比如将claude渠道的400错误复写为500用于重试请勿滥用该功能例如\n${JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)}`}
name='status_code_mapping'
onChange={(value) => {
handleInputChange('status_code_mapping', value);
}}
autosize
value={inputs.status_code_mapping}
autoComplete='new-password'
/>
<Typography.Text
style={{
color: 'rgba(var(--semi-blue-5), 1)',
userSelect: 'none',
cursor: 'pointer',
}}
onClick={() => {
handleInputChange(
'status_code_mapping',
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2),
);
}}
>
填入模板
</Typography.Text>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>密钥</Typography.Text> <Typography.Text strong>密钥</Typography.Text>
</div> </div>