mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: support baidu's embedding model (close #324)
This commit is contained in:
		@@ -42,6 +42,7 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"claude-2":                30,
 | 
			
		||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                  1,
 | 
			
		||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
 
 | 
			
		||||
@@ -288,6 +288,15 @@ func init() {
 | 
			
		||||
			Root:       "ERNIE-Bot-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "Embedding-V1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "baidu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "Embedding-V1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "PaLM-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
 
 | 
			
		||||
@@ -54,6 +54,25 @@ type BaiduChatStreamResponse struct {
 | 
			
		||||
	IsEnd      bool `json:"is_end"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BaiduEmbeddingRequest is the request payload built for Baidu embedding
// calls. It carries the texts to embed, serialized under the "input" key.
type BaiduEmbeddingRequest struct {
	// Input holds the texts to embed.
	Input []string `json:"input"`
}
 | 
			
		||||
 | 
			
		||||
// BaiduEmbeddingData is a single entry of the "data" array in a Baidu
// embedding response.
type BaiduEmbeddingData struct {
	// Object is the "object" field of the entry, copied through unchanged.
	Object string `json:"object"`
	// Embedding is the vector decoded from the "embedding" field.
	Embedding []float64 `json:"embedding"`
	// Index is the "index" field of the entry.
	Index int `json:"index"`
}
 | 
			
		||||
 | 
			
		||||
type BaiduEmbeddingResponse struct {
 | 
			
		||||
	Id      string               `json:"id"`
 | 
			
		||||
	Object  string               `json:"object"`
 | 
			
		||||
	Created int64                `json:"created"`
 | 
			
		||||
	Data    []BaiduEmbeddingData `json:"data"`
 | 
			
		||||
	Usage   Usage                `json:"usage"`
 | 
			
		||||
	BaiduError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
@@ -112,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
			
		||||
	baiduEmbeddingRequest := BaiduEmbeddingRequest{
 | 
			
		||||
		Input: nil,
 | 
			
		||||
	}
 | 
			
		||||
	switch request.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
 | 
			
		||||
	case []string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = request.Input.([]string)
 | 
			
		||||
	}
 | 
			
		||||
	return &baiduEmbeddingRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
 | 
			
		||||
		Model:  "baidu-embedding",
 | 
			
		||||
		Usage:  response.Usage,
 | 
			
		||||
	}
 | 
			
		||||
	for _, item := range response.Data {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
			
		||||
			Object:    item.Object,
 | 
			
		||||
			Index:     item.Index,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
@@ -212,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var baiduResponse BaiduEmbeddingResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &baiduResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if baiduResponse.ErrorMsg != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: baiduResponse.ErrorMsg,
 | 
			
		||||
				Type:    "baidu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    baiduResponse.ErrorCode,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -139,6 +139,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 | 
			
		||||
		case "BLOOMZ-7B":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 | 
			
		||||
		case "Embedding-V1":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
 | 
			
		||||
		}
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
@@ -212,12 +214,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeBaidu:
 | 
			
		||||
		baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(baiduRequest)
 | 
			
		||||
		var jsonData []byte
 | 
			
		||||
		var err error
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case RelayModeEmbeddings:
 | 
			
		||||
			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
 | 
			
		||||
			jsonData, err = json.Marshal(baiduEmbeddingRequest)
 | 
			
		||||
		default:
 | 
			
		||||
			baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
			
		||||
			jsonData, err = json.Marshal(baiduRequest)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonData)
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		palmRequest := requestOpenAI2PaLM(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(palmRequest)
 | 
			
		||||
@@ -386,7 +396,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := baiduHandler(c, resp)
 | 
			
		||||
			var err *OpenAIErrorWithStatusCode
 | 
			
		||||
			var usage *Usage
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case RelayModeEmbeddings:
 | 
			
		||||
				err, usage = baiduEmbeddingHandler(c, resp)
 | 
			
		||||
			default:
 | 
			
		||||
				err, usage = baiduHandler(c, resp)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -99,6 +99,19 @@ type OpenAITextResponse struct {
 | 
			
		||||
	Usage   `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenAIEmbeddingResponseItem is one entry of the "data" array in an
// OpenAI-format embedding response.
type OpenAIEmbeddingResponseItem struct {
	// Object is serialized under the "object" key.
	Object string `json:"object"`
	// Index is serialized under the "index" key.
	Index int `json:"index"`
	// Embedding is the vector serialized under the "embedding" key.
	Embedding []float64 `json:"embedding"`
}
 | 
			
		||||
 | 
			
		||||
type OpenAIEmbeddingResponse struct {
 | 
			
		||||
	Object string                        `json:"object"`
 | 
			
		||||
	Data   []OpenAIEmbeddingResponseItem `json:"data"`
 | 
			
		||||
	Model  string                        `json:"model"`
 | 
			
		||||
	Usage  `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user