mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 21:33:41 +08:00 
			
		
		
		
	feat: support cohere rerank
This commit is contained in:
		@@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case relayconstant.RelayModeAudioTranscription:
 | 
			
		||||
		err = relay.AudioHelper(c, relayMode)
 | 
			
		||||
	case relayconstant.RelayModeRerank:
 | 
			
		||||
		err = relay.RerankHelper(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = relay.TextHelper(c)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										19
									
								
								dto/rerank.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								dto/rerank.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
// RerankRequest is the JSON body decoded for a rerank relay request.
type RerankRequest struct {
	// Documents are the candidate documents; each element is kept as its
	// decoded JSON value and forwarded upstream unchanged.
	Documents []any `json:"documents"`
	// Query is the "query" field of the request.
	Query string `json:"query"`
	// Model is the name of the requested model.
	Model string `json:"model"`
	// TopN is the "top_n" field of the request.
	TopN int `json:"top_n"`
}
 | 
			
		||||
 | 
			
		||||
// RerankResponseDocument is a single entry of a rerank result list.
type RerankResponseDocument struct {
	// Document is the "document" value of the entry, kept as decoded JSON.
	Document any `json:"document"`
	// Index is the "index" value of the entry.
	Index int `json:"index"`
	// RelevanceScore is the "relevance_score" value of the entry.
	RelevanceScore float64 `json:"relevance_score"`
}
 | 
			
		||||
 | 
			
		||||
type RerankResponse struct {
 | 
			
		||||
	Results []RerankResponseDocument `json:"results"`
 | 
			
		||||
	Usage   Usage                    `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
@@ -11,9 +11,11 @@ import (
 | 
			
		||||
type Adaptor interface {
 | 
			
		||||
	// Init IsStream bool
 | 
			
		||||
	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
 | 
			
		||||
	InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
 | 
			
		||||
	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 | 
			
		||||
	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 | 
			
		||||
	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
 | 
			
		||||
	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 | 
			
		||||
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,9 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,11 @@ type Adaptor struct {
 | 
			
		||||
	RequestMode int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 | 
			
		||||
		a.RequestMode = RequestModeMessage
 | 
			
		||||
@@ -53,6 +58,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return claudeReq, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -108,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,11 @@ type Adaptor struct {
 | 
			
		||||
	RequestMode int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
 | 
			
		||||
		a.RequestMode = RequestModeMessage
 | 
			
		||||
@@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,16 +8,24 @@ import (
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRerank {
 | 
			
		||||
		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
@@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return requestConvertRerank2Cohere(request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = cohereStreamHandler(c, resp, info)
 | 
			
		||||
	if info.RelayMode == constant.RelayModeRerank {
 | 
			
		||||
		err, usage = cohereRerankHandler(c, resp, info)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		if info.IsStream {
 | 
			
		||||
			err, usage = cohereStreamHandler(c, resp, info)
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package cohere
 | 
			
		||||
 | 
			
		||||
// ModelList enumerates the model names listed for the cohere channel.
var ModelList = []string{
	// command-* models
	"command-r",
	"command-r-plus",
	"command-light",
	"command-light-nightly",
	"command",
	"command-nightly",
	// rerank-* models
	"rerank-english-v3.0",
	"rerank-multilingual-v3.0",
	"rerank-english-v2.0",
	"rerank-multilingual-v2.0",
}
 | 
			
		||||
 | 
			
		||||
var ChannelName = "cohere"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,7 @@
 | 
			
		||||
package cohere
 | 
			
		||||
 | 
			
		||||
import "one-api/dto"
 | 
			
		||||
 | 
			
		||||
type CohereRequest struct {
 | 
			
		||||
	Model       string        `json:"model"`
 | 
			
		||||
	ChatHistory []ChatHistory `json:"chat_history"`
 | 
			
		||||
@@ -28,6 +30,19 @@ type CohereResponseResult struct {
 | 
			
		||||
	Meta         CohereMeta `json:"meta"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CohereRerankRequest is the JSON payload built for Cohere's rerank endpoint.
type CohereRerankRequest struct {
	// Documents are the candidate documents, forwarded as decoded JSON values.
	Documents []any `json:"documents"`
	// Query is the "query" field of the payload.
	Query string `json:"query"`
	// Model is the rerank model name.
	Model string `json:"model"`
	// TopN is the "top_n" field of the payload.
	TopN int `json:"top_n"`
	// ReturnDocuments is the "return_documents" flag of the payload.
	ReturnDocuments bool `json:"return_documents"`
}
 | 
			
		||||
 | 
			
		||||
type CohereRerankResponseResult struct {
 | 
			
		||||
	Results []dto.RerankResponseDocument `json:"results"`
 | 
			
		||||
	Meta    CohereMeta                   `json:"meta"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CohereMeta struct {
 | 
			
		||||
	//Tokens CohereTokens `json:"tokens"`
 | 
			
		||||
	BilledUnits CohereBilledUnits `json:"billed_units"`
 | 
			
		||||
 
 | 
			
		||||
@@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 | 
			
		||||
	return &cohereReq
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
 | 
			
		||||
	cohereReq := CohereRerankRequest{
 | 
			
		||||
		Query:           rerankRequest.Query,
 | 
			
		||||
		Documents:       rerankRequest.Documents,
 | 
			
		||||
		Model:           rerankRequest.Model,
 | 
			
		||||
		TopN:            rerankRequest.TopN,
 | 
			
		||||
		ReturnDocuments: true,
 | 
			
		||||
	}
 | 
			
		||||
	for _, doc := range rerankRequest.Documents {
 | 
			
		||||
		cohereReq.Documents = append(cohereReq.Documents, doc)
 | 
			
		||||
	}
 | 
			
		||||
	return &cohereReq
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func stopReasonCohere2OpenAI(reason string) string {
 | 
			
		||||
	switch reason {
 | 
			
		||||
	case "COMPLETE":
 | 
			
		||||
@@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	var cohereResp CohereRerankResponseResult
 | 
			
		||||
	err = json.Unmarshal(responseBody, &cohereResp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	usage := dto.Usage{}
 | 
			
		||||
	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
 | 
			
		||||
		usage.PromptTokens = info.PromptTokens
 | 
			
		||||
		usage.CompletionTokens = 0
 | 
			
		||||
		usage.TotalTokens = info.PromptTokens
 | 
			
		||||
	} else {
 | 
			
		||||
		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
 | 
			
		||||
		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
 | 
			
		||||
		usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var rerankResp dto.RerankResponse
 | 
			
		||||
	rerankResp.Results = cohereResp.Results
 | 
			
		||||
	rerankResp.Usage = usage
 | 
			
		||||
 | 
			
		||||
	jsonResponse, err := json.Marshal(rerankResp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -14,6 +14,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -34,6 +39,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return requestOpenAI2Dify(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,9 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -56,6 +59,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return CovertGemini2OpenAI(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,9 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,13 @@ type Adaptor struct {
 | 
			
		||||
	ChannelType int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
	a.ChannelType = info.ChannelType
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -39,6 +44,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return requestOpenAI2Perplexity(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,11 @@ type Adaptor struct {
 | 
			
		||||
	Timestamp int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
	a.Action = "ChatCompletions"
 | 
			
		||||
	a.Version = "2023-09-01"
 | 
			
		||||
@@ -57,6 +62,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return tencentRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,11 @@ type Adaptor struct {
 | 
			
		||||
	request *dto.GeneralOpenAIRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -36,6 +41,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	// xunfei's request is not http request, so we don't need to do anything here
 | 
			
		||||
	dummyResp := &http.Response{}
 | 
			
		||||
 
 | 
			
		||||
@@ -14,6 +14,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -42,6 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return requestOpenAI2Zhipu(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,11 @@ import (
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -40,6 +45,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
	return requestOpenAI2Zhipu(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,7 @@ const (
 | 
			
		||||
	RelayModeSunoFetch
 | 
			
		||||
	RelayModeSunoFetchByID
 | 
			
		||||
	RelayModeSunoSubmit
 | 
			
		||||
	RelayModeRerank
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Path2RelayMode(path string) int {
 | 
			
		||||
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
 | 
			
		||||
		relayMode = RelayModeAudioTranscription
 | 
			
		||||
	} else if strings.HasPrefix(path, "/v1/audio/translations") {
 | 
			
		||||
		relayMode = RelayModeAudioTranslation
 | 
			
		||||
	} else if strings.HasPrefix(path, "/v1/rerank") {
 | 
			
		||||
		relayMode = RelayModeRerank
 | 
			
		||||
	}
 | 
			
		||||
	return relayMode
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -182,7 +182,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 | 
			
		||||
	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -272,7 +272,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 | 
			
		||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 | 
			
		||||
	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 | 
			
		||||
	modelPrice float64, usePrice bool) {
 | 
			
		||||
 | 
			
		||||
@@ -281,7 +281,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 | 
			
		||||
	completionTokens := usage.CompletionTokens
 | 
			
		||||
 | 
			
		||||
	tokenName := ctx.GetString("token_name")
 | 
			
		||||
	completionRatio := common.GetCompletionRatio(textRequest.Model)
 | 
			
		||||
	completionRatio := common.GetCompletionRatio(modelName)
 | 
			
		||||
 | 
			
		||||
	quota := 0
 | 
			
		||||
	if !usePrice {
 | 
			
		||||
@@ -307,7 +307,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 | 
			
		||||
		// we cannot just return, because we may have to return the pre-consumed quota
 | 
			
		||||
		quota = 0
 | 
			
		||||
		logContent += fmt.Sprintf("(可能是上游超时)")
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
 | 
			
		||||
			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
 | 
			
		||||
	} else {
 | 
			
		||||
		//if sensitiveResp != nil {
 | 
			
		||||
		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
 | 
			
		||||
@@ -327,13 +328,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 | 
			
		||||
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logModel := textRequest.Model
 | 
			
		||||
	logModel := modelName
 | 
			
		||||
	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 | 
			
		||||
		logModel = "gpt-4-gizmo-*"
 | 
			
		||||
		logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
 | 
			
		||||
		logContent += fmt.Sprintf(",模型 %s", modelName)
 | 
			
		||||
	}
 | 
			
		||||
	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
 | 
			
		||||
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
 | 
			
		||||
	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
 | 
			
		||||
		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
 | 
			
		||||
 | 
			
		||||
	//if quota != 0 {
 | 
			
		||||
	//
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										104
									
								
								relay/relay_rerank.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								relay/relay_rerank.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,104 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 | 
			
		||||
	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
 | 
			
		||||
	for _, document := range rerankRequest.Documents {
 | 
			
		||||
		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			token += tkm
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return token
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	relayInfo := relaycommon.GenRelayInfo(c)
 | 
			
		||||
 | 
			
		||||
	var rerankRequest *dto.RerankRequest
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &rerankRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if rerankRequest.Query == "" {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if len(rerankRequest.Documents) == 0 {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	relayInfo.UpstreamModelName = rerankRequest.Model
 | 
			
		||||
	modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(relayInfo.Group)
 | 
			
		||||
 | 
			
		||||
	var preConsumedQuota int
 | 
			
		||||
	var ratio float64
 | 
			
		||||
	var modelRatio float64
 | 
			
		||||
 | 
			
		||||
	promptToken := getRerankPromptToken(*rerankRequest)
 | 
			
		||||
	if !success {
 | 
			
		||||
		preConsumedTokens := promptToken
 | 
			
		||||
		modelRatio = common.GetModelRatio(rerankRequest.Model)
 | 
			
		||||
		ratio = modelRatio * groupRatio
 | 
			
		||||
		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	} else {
 | 
			
		||||
		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
 | 
			
		||||
	}
 | 
			
		||||
	relayInfo.PromptTokens = promptToken
 | 
			
		||||
 | 
			
		||||
	// pre-consume quota 预消耗配额
 | 
			
		||||
	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	adaptor := GetAdaptor(relayInfo.ApiType)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	adaptor.InitRerank(relayInfo, *rerankRequest)
 | 
			
		||||
 | 
			
		||||
	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	jsonData, err := json.Marshal(convertedRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	requestBody := bytes.NewBuffer(jsonData)
 | 
			
		||||
	statusCodeMappingStr := c.GetString("status_code_mapping")
 | 
			
		||||
	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if resp != nil {
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
			openaiErr := service.RelayErrorHandler(resp)
 | 
			
		||||
			// reset status code 重置状态码
 | 
			
		||||
			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
		// reset status code 重置状态码
 | 
			
		||||
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -42,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/moderations", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/rerank", controller.Relay)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	relayMjRouter := router.Group("/mj")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user