Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-11-04 15:53:42 +08:00)
	feat: add embedding-2 support for zhipu (#1273)
* Add support for Zhipu's embedding-2 model
* fix: fix usage & ratio

Co-authored-by: yangfei <yangfei@xuyao.info>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
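With this change, a client can hit one-api's OpenAI-compatible embeddings endpoint and have it served by Zhipu's embedding-2. A minimal client-side sketch, assuming a gateway at http://localhost:3000 and a placeholder token sk-xxx (both illustrative, not part of the commit); note the string-valued `input`, which matters for the ConvertRequest hunk below:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Request body in the standard OpenAI embeddings shape.
	body := []byte(`{"model": "embedding-2", "input": "hello world"}`)

	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-xxx") // placeholder one-api token

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // OpenAI-shaped embedding list with a usage object
}
```

The hunks below wire this up end to end: the billing ratio, the adaptor's imports, URL routing, request conversion, response handling, the advertised model list, and the response types.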
```diff
@@ -91,6 +91,7 @@ var ModelRatio = map[string]float64{
 	"glm-4":                     0.1 * RMB,
 	"glm-4v":                    0.1 * RMB,
 	"glm-3-turbo":               0.005 * RMB,
+	"embedding-2":               0.0005 * RMB,
 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
```
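The new ratio can be checked against the comments already in the table. A sketch of the arithmetic, under the assumption (inferred from the ¥-comments, not read from the one-api source) that one ratio unit is $0.002 per 1k tokens and the RMB constant encodes a 7:1 CNY:USD rate:

```go
package main

import "fmt"

func main() {
	// chatglm_pro's ratio of 0.7143 is annotated "¥0.01 / 1k tokens", which
	// implies one ratio unit ≈ $0.002 per 1k tokens at 7 CNY per USD.
	const usdPer1kAtRatio1 = 0.002
	const usdToCNY = 7.0
	const rmb = 1 / (usdPer1kAtRatio1 * usdToCNY) // ≈ 71.43, stand-in for the RMB constant

	fmt.Printf("chatglm_pro (¥0.01/1k):   %.4f\n", 0.01*rmb)   // ≈ 0.7143, matches the table
	fmt.Printf("embedding-2 (¥0.0005/1k): %.4f\n", 0.0005*rmb) // ≈ 0.0357
}
```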
```diff
@@ -6,6 +6,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
```
```diff
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	if a.APIVersion == "v4" {
 		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
 	}
+	if meta.Mode == constant.RelayModeEmbeddings {
+		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
+	}
 	method := "invoke"
 	if meta.IsStream {
 		method = "sse-invoke"
```
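The effect of the new branch: embedding requests always resolve to the v4 /embeddings path, while non-v4 chat still falls through to the legacy invoke/sse-invoke routing. A standalone sketch (RelayMeta and the mode constant are simplified stand-ins; the legacy chat URL is elided):

```go
package main

import "fmt"

const relayModeEmbeddings = 1 // stand-in for constant.RelayModeEmbeddings

type relayMeta struct {
	BaseURL string
	Mode    int
}

func requestURL(apiVersion string, meta relayMeta) string {
	if apiVersion == "v4" {
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL)
	}
	// New in this commit: embeddings are only served by the v4 PaaS API, so
	// they skip the legacy invoke/sse-invoke routing below.
	if meta.Mode == relayModeEmbeddings {
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL)
	}
	return fmt.Sprintf("%s/api/paas/v3/model-api/.../invoke", meta.BaseURL) // legacy chat path, elided
}

func main() {
	m := relayMeta{BaseURL: "https://open.bigmodel.cn", Mode: relayModeEmbeddings}
	fmt.Println(requestURL("", m)) // https://open.bigmodel.cn/api/paas/v4/embeddings
}
```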
```diff
@@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	// TopP (0.0, 1.0)
-	request.TopP = math.Min(0.99, request.TopP)
-	request.TopP = math.Max(0.01, request.TopP)
+	switch relayMode {
+	case constant.RelayModeEmbeddings:
+		baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+		return baiduEmbeddingRequest, nil
+	default:
+		// TopP (0.0, 1.0)
+		request.TopP = math.Min(0.99, request.TopP)
+		request.TopP = math.Max(0.01, request.TopP)
 
-	// Temperature (0.0, 1.0)
-	request.Temperature = math.Min(0.99, request.Temperature)
-	request.Temperature = math.Max(0.01, request.Temperature)
-	a.SetVersionByModeName(request.Model)
-	if a.APIVersion == "v4" {
-		return request, nil
+		// Temperature (0.0, 1.0)
+		request.Temperature = math.Min(0.99, request.Temperature)
+		request.Temperature = math.Max(0.01, request.Temperature)
+		a.SetVersionByModeName(request.Model)
+		if a.APIVersion == "v4" {
+			return request, nil
+		}
+		return ConvertRequest(*request), nil
 	}
-	return ConvertRequest(*request), nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
```
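In the rewritten function, the embeddings case returns early, so the sampling parameters below it are never touched for embedding requests. For chat, the (0.0, 1.0) comments suggest Zhipu's legacy API wants top_p and temperature strictly inside the open interval, hence the pin to [0.01, 0.99]. A minimal sketch of that clamp:

```go
package main

import (
	"fmt"
	"math"
)

// clampOpenUnit pins a sampling parameter into [0.01, 0.99], mirroring the
// TopP/Temperature handling in the default branch above.
func clampOpenUnit(v float64) float64 {
	return math.Max(0.01, math.Min(0.99, v))
}

func main() {
	fmt.Println(clampOpenUnit(0))   // 0.01
	fmt.Println(clampOpenUnit(1))   // 0.99
	fmt.Println(clampOpenUnit(0.7)) // 0.7
}
```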
```diff
@@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 	if a.APIVersion == "v4" {
 		return a.DoResponseV4(c, resp, meta)
 	}
 
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
-		err, usage = Handler(c, resp)
+		if meta.Mode == constant.RelayModeEmbeddings {
+			err, usage = EmbeddingsHandler(c, resp)
+		} else {
+			err, usage = Handler(c, resp)
+		}
 	}
 	return
 }
 
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	return &EmbeddingRequest{
+		Model: "embedding-2",
+		Input: request.Input.(string),
+	}
+}
+
 func (a *Adaptor) GetModelList() []string {
 	return ModelList
 }
```
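Two review notes on ConvertEmbeddingRequest: the variable name baiduEmbeddingRequest in the calling switch looks carried over from the Baidu adaptor, and request.Input.(string) is an unchecked type assertion that panics when a client sends the array form of the OpenAI "input" field. A more defensive conversion might look like the following sketch (convertEmbeddingInput is a hypothetical helper, not part of this commit; joining the array elements also collapses a batch into one embedding, so a full fix would fan out one upstream request per element):

```go
package main

import (
	"fmt"
	"strings"
)

// convertEmbeddingInput accepts both the string and array forms of the
// OpenAI embeddings "input" field instead of panicking on non-strings.
func convertEmbeddingInput(input any) (string, error) {
	switch v := input.(type) {
	case string:
		return v, nil
	case []any: // JSON arrays decode into []any
		parts := make([]string, 0, len(v))
		for _, item := range v {
			s, ok := item.(string)
			if !ok {
				return "", fmt.Errorf("unsupported input element type %T", item)
			}
			parts = append(parts, s)
		}
		return strings.Join(parts, "\n"), nil
	default:
		return "", fmt.Errorf("unsupported input type %T", input)
	}
}

func main() {
	s, _ := convertEmbeddingInput([]any{"first chunk", "second chunk"})
	fmt.Println(s)
}
```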
```diff
@@ -2,5 +2,5 @@ package zhipu
 
 var ModelList = []string{
 	"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
-	"glm-4", "glm-4v", "glm-3-turbo",
+	"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
 }
```
```diff
@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var zhipuResponse EmbeddingRespone
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &zhipuResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
+	openAIEmbeddingResponse := openai.EmbeddingResponse{
+		Object: "list",
+		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
+		Model:  response.Model,
+		Usage: model.Usage{
+			PromptTokens:     response.PromptTokens,
+			CompletionTokens: response.CompletionTokens,
+			TotalTokens:      response.Usage.TotalTokens,
+		},
+	}
+
+	for _, item := range response.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
```
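The commit message's "fix usage & ratio" presumably refers to this handler pair: the usage object is carried through rather than recounted, so billing sees the upstream token counts. What the converter emits, as a standalone sketch (types trimmed to the fields this hunk touches; the upstream JSON shape is inferred from the json tags added in model.go, not captured from the live API):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type embeddingData struct {
	Index     int       `json:"index"`
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
}

type zhipuEmbeddingResponse struct {
	Model string          `json:"model"`
	Data  []embeddingData `json:"data"`
}

type openAIItem struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

type openAIList struct {
	Object string       `json:"object"`
	Data   []openAIItem `json:"data"`
	Model  string       `json:"model"`
}

func main() {
	upstream := []byte(`{"model":"embedding-2","data":[{"index":0,"object":"embedding","embedding":[0.1,0.2]}]}`)
	var in zhipuEmbeddingResponse
	if err := json.Unmarshal(upstream, &in); err != nil {
		panic(err)
	}
	// Re-wrap the upstream payload into the OpenAI list shape, field by field,
	// the same way embeddingResponseZhipu2OpenAI does above.
	out := openAIList{Object: "list", Model: in.Model}
	for _, item := range in.Data {
		out.Data = append(out.Data, openAIItem{Object: "embedding", Index: item.Index, Embedding: item.Embedding})
	}
	b, _ := json.Marshal(out)
	fmt.Println(string(b))
}
```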
```diff
@@ -44,3 +44,21 @@ type tokenData struct {
 	Token      string
 	ExpiryTime time.Time
 }
+
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input string `json:"input"`
+}
+
+type EmbeddingRespone struct {
+	Model       string          `json:"model"`
+	Object      string          `json:"object"`
+	Embeddings  []EmbeddingData `json:"data"`
+	model.Usage `json:"usage"`
+}
+
+type EmbeddingData struct {
+	Index     int       `json:"index"`
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+}
```
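One Go detail worth noting in these types (the misspelled EmbeddingRespone is verbatim from the commit): model.Usage is embedded with a json tag, so the nested "usage" object unmarshals into it while its fields stay promoted, which is why the handler can read response.PromptTokens directly. A self-contained sketch of that pattern:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type EmbeddingResponse struct {
	Model string `json:"model"`
	Usage `json:"usage"` // embedded with a tag: the nested object fills the promoted fields
}

func main() {
	body := []byte(`{"model":"embedding-2","usage":{"prompt_tokens":4,"completion_tokens":0,"total_tokens":4}}`)
	var r EmbeddingResponse
	if err := json.Unmarshal(body, &r); err != nil {
		panic(err)
	}
	// Promotion lets callers write r.PromptTokens instead of r.Usage.PromptTokens.
	// For embeddings, completion_tokens is typically zero and total_tokens
	// carries the real count.
	fmt.Println(r.PromptTokens, r.CompletionTokens, r.TotalTokens) // 4 0 4
}
```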