diff --git a/controller/relay.go b/controller/relay.go index 03853c1..a04c85a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode fallthrough case relayconstant.RelayModeAudioTranscription: err = relay.AudioHelper(c, relayMode) + case relayconstant.RelayModeRerank: + err = relay.RerankHelper(c, relayMode) default: err = relay.TextHelper(c) } diff --git a/dto/rerank.go b/dto/rerank.go new file mode 100644 index 0000000..0ee44b1 --- /dev/null +++ b/dto/rerank.go @@ -0,0 +1,19 @@ +package dto + +type RerankRequest struct { + Documents []any `json:"documents"` + Query string `json:"query"` + Model string `json:"model"` + TopN int `json:"top_n"` +} + +type RerankResponseDocument struct { + Document any `json:"document"` + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` +} + +type RerankResponse struct { + Results []RerankResponseDocument `json:"results"` + Usage Usage `json:"usage"` +} diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d87f476..e222a70 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -11,9 +11,11 @@ import ( type Adaptor interface { // Init IsStream bool Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) + InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) + ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string diff --git 
a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfe83db..fbaf546 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -15,6 +15,9 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 23c69db..4de3a3a 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,6 +20,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage @@ -53,6 +58,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return claudeReq, err } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return nil, nil } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 23f2d06..17f5384 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) 
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -108,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index d302265..4623318 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -21,6 +21,11 @@ type Adaptor struct { RequestMode int } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage @@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index cd01634..b5f3521 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -8,16 +8,24 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/relay/constant" ) type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + 
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + if info.RelayMode == constant.RelayModeRerank { + return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + } else { + return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { @@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return requestConvertRerank2Cohere(request), nil +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = cohereStreamHandler(c, resp, info) + if info.RelayMode == constant.RelayModeRerank { + err, usage = cohereRerankHandler(c, resp, info) } else { - err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + if info.IsStream { + err, usage = cohereStreamHandler(c, resp, info) + } else { + err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + } } return } diff --git a/relay/channel/cohere/constant.go b/relay/channel/cohere/constant.go index 189d234..8f34e4f 100644 --- a/relay/channel/cohere/constant.go +++ b/relay/channel/cohere/constant.go @@ -2,6 +2,7 @@ package cohere var ModelList = []string{ "command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly", + "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0", } var ChannelName = "cohere" diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index 
958343c..fc6c445 100644
--- a/relay/channel/cohere/dto.go
+++ b/relay/channel/cohere/dto.go
@@ -1,5 +1,7 @@
 package cohere
 
+import "one-api/dto"
+
 type CohereRequest struct {
 	Model string `json:"model"`
 	ChatHistory []ChatHistory `json:"chat_history"`
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
 	Meta CohereMeta `json:"meta"`
 }
 
+type CohereRerankRequest struct {
+	Documents       []any  `json:"documents"`
+	Query           string `json:"query"`
+	Model           string `json:"model"`
+	TopN            int    `json:"top_n"`
+	ReturnDocuments bool   `json:"return_documents"`
+}
+
+type CohereRerankResponseResult struct {
+	Results []dto.RerankResponseDocument `json:"results"`
+	Meta    CohereMeta                   `json:"meta"`
+}
+
 type CohereMeta struct {
 	//Tokens CohereTokens `json:"tokens"`
 	BilledUnits CohereBilledUnits `json:"billed_units"`
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index cc424b0..d20acb6 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 	return &cohereReq
 }
 
+// requestConvertRerank2Cohere maps a generic rerank request onto the Cohere
+// rerank payload. The documents are forwarded as-is, exactly once, and
+// ReturnDocuments is always set so each result carries its document.
+func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
+	cohereReq := CohereRerankRequest{
+		Query:           rerankRequest.Query,
+		Documents:       rerankRequest.Documents,
+		Model:           rerankRequest.Model,
+		TopN:            rerankRequest.TopN,
+		ReturnDocuments: true,
+	}
+	return &cohereReq
+}
+
 func stopReasonCohere2OpenAI(reason string) string {
 	switch reason {
 	case "COMPLETE":
@@ -194,3 +208,49 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &usage
 }
+
+// cohereRerankHandler reads the upstream Cohere rerank response, re-encodes
+// its results as a dto.RerankResponse and writes that to the client with the
+// upstream status code. Usage is taken from Cohere's billed units when a
+// non-zero input token count is reported; otherwise it falls back to
+// info.PromptTokens with zero completion tokens.
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var cohereResp CohereRerankResponseResult
+	err = json.Unmarshal(responseBody, &cohereResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	usage := dto.Usage{}
+	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens = 0
+		usage.TotalTokens = info.PromptTokens
+	} else {
+		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+		usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+	}
+
+	var rerankResp dto.RerankResponse
+	rerankResp.Results = cohereResp.Results
+	rerankResp.Usage = usage
+
+	jsonResponse, err := json.Marshal(rerankResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	// The write error is deliberately discarded; usage is returned either way
+	// so the caller can still settle the quota for this request.
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go
index 99f3792..a54b95b 100644
--- a/relay/channel/dify/adaptor.go
+++ b/relay/channel/dify/adaptor.go
@@ -14,6 +14,11 @@ import (
 type Adaptor struct {
 }
 
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+	//TODO implement me
+
+}
+
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 }
 
@@ -34,6 +39,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	return requestOpenAI2Dify(*request), nil
 }
 
+func (a *Adaptor) ConvertRerankRequest(c
*gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 875361c..f51ae3f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -15,6 +15,9 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -56,6 +59,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return CovertGemini2OpenAI(*request), nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index e1225ee..4bf1d61 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -16,6 +16,9 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git 
a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 58f5ab5..0c1ce25 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -22,6 +22,13 @@ type Adaptor struct { ChannelType int } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { a.ChannelType = info.ChannelType } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4f59a44..8f6dd0a 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,6 +15,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return request, nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index ef87df8..c3972d5 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, 
request *dto.Gen return requestOpenAI2Perplexity(*request), nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 33eda3f..42d4f12 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -22,6 +22,11 @@ type Adaptor struct { Timestamp int64 } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { a.Action = "ChatCompletions" a.Version = "2023-09-01" @@ -57,6 +62,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return tencentRequest, nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 79a4b12..9852aa1 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,6 +16,11 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return request, nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request 
dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 6f2d186..0893a83 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,6 +14,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return requestOpenAI2Zhipu(*request), nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 040cf38..eaf3087 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -16,6 +16,11 @@ import ( type Adaptor struct { } +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { + //TODO implement me + +} + func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { } @@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return requestOpenAI2Zhipu(*request), nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, 
requestBody io.Reader) (*http.Response, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index fa19f50..ed15b08 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -32,6 +32,7 @@ const ( RelayModeSunoFetch RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeRerank ) func Path2RelayMode(path string) int { @@ -56,6 +57,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeAudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = RelayModeAudioTranslation + } else if strings.HasPrefix(path, "/v1/rerank") { + relayMode = RelayModeRerank } return relayMode } diff --git a/relay/relay-text.go b/relay/relay-text.go index 8673286..28b5d35 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -182,7 +182,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) return nil } @@ -272,7 +272,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu } } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64, usePrice bool) { @@ -281,7 +281,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe completionTokens := usage.CompletionTokens tokenName := ctx.GetString("token_name") - completionRatio := 
common.GetCompletionRatio(textRequest.Model) + completionRatio := common.GetCompletionRatio(modelName) quota := 0 if !usePrice { @@ -307,7 +307,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe // we cannot just return, because we may have to return the pre-consumed quota quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) } else { //if sensitiveResp != nil { // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) @@ -327,13 +328,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - logModel := textRequest.Model + logModel := modelName if strings.HasPrefix(logModel, "gpt-4-gizmo") { logModel = "gpt-4-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", textRequest.Model) + logContent += fmt.Sprintf(",模型 %s", modelName) } other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) //if quota != 0 { // diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go 
new file mode 100644
index 0000000..e32ca88
--- /dev/null
+++ b/relay/relay_rerank.go
@@ -0,0 +1,126 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+// getRerankPromptToken estimates the prompt size of a rerank request by summing
+// the token counts of the query and of every document. Counting errors are
+// ignored, so the result is a best-effort estimate.
+func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
+	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	for _, document := range rerankRequest.Documents {
+		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
+		if err == nil {
+			token += tkm
+		}
+	}
+	return token
+}
+
+// RerankHelper relays a rerank request: it validates the body, pre-consumes
+// quota, sends the adaptor-converted request upstream and settles the final
+// quota from the usage reported by the adaptor. Every failure after the
+// pre-consume step returns the pre-consumed quota before reporting the error.
+func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	var rerankRequest *dto.RerankRequest
+	err := common.UnmarshalBodyReusable(c, &rerankRequest)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("parse rerank request failed: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	}
+	// The pointer can still be nil when nothing was decoded into it (e.g. a JSON
+	// null body); reject that instead of panicking on the field access below.
+	if rerankRequest == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("request body is empty"), "invalid_text_request", http.StatusBadRequest)
+	}
+	if rerankRequest.Query == "" {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
+	}
+	relayInfo.UpstreamModelName = rerankRequest.Model
+	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+
+	promptToken := getRerankPromptToken(*rerankRequest)
+	if !success {
+		preConsumedTokens := promptToken
+		modelRatio = common.GetModelRatio(rerankRequest.Model)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	relayInfo.PromptTokens = promptToken
+
+	// pre-consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.InitRerank(relayInfo, *rerankRequest)
+
+	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+	if err != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+	}
+	if convertedRequest == nil {
+		// Adaptors without rerank support return (nil, nil); marshalling that
+		// would send a literal "null" body upstream.
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("rerank is not supported by this channel"), "convert_request_failed", http.StatusBadRequest)
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	if resp != nil {
+		if resp.StatusCode != http.StatusOK {
+			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	return nil
+}
diff --git a/router/relay-router.go b/router/relay-router.go
index 3ad9e37..2bf2ca2 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@
-42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) relayV1Router.POST("/moderations", controller.Relay) + relayV1Router.POST("/rerank", controller.Relay) } relayMjRouter := router.Group("/mj")