From 144513f1d841fd178403b1ec6b117f06b053b442 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 23 Aug 2024 23:21:37 +0800 Subject: [PATCH] feat: rerank model mapping (close #444) --- relay/relay_rerank.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 9885fd3..4242155 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -38,6 +38,23 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if len(rerankRequest.Documents) == 0 { return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } + + // map model name + modelMapping := c.GetString("model_mapping") + //isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[rerankRequest.Model] != "" { + rerankRequest.Model = modelMap[rerankRequest.Model] + // set upstream model name + //isModelMapped = true + } + } + relayInfo.UpstreamModelName = rerankRequest.Model modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) groupRatio := common.GetGroupRatio(relayInfo.Group)