From bccdcca7cbec8045524f781e20193cf3616c3add Mon Sep 17 00:00:00 2001 From: LUO Feng Date: Wed, 15 Jan 2025 11:58:59 +0800 Subject: [PATCH] feat: support hunyuan-embedding --- relay/adaptor/tencent/adaptor.go | 36 +++++++++++++--- relay/adaptor/tencent/constants.go | 1 + relay/adaptor/tencent/main.go | 66 ++++++++++++++++++++++++++++-- relay/adaptor/tencent/model.go | 46 ++++++++++++++++----- 4 files changed, 131 insertions(+), 18 deletions(-) diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 0de92d4a..a98f63c5 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -1,6 +1,7 @@ package tencent import ( + "encoding/json" "errors" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" @@ -8,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strconv" @@ -52,10 +54,29 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if err != nil { return nil, err } - tencentRequest := ConvertRequest(*request) - // we have to calculate the sign here - a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) - return tencentRequest, nil + + switch relayMode { + case relaymode.Embeddings: + a.Action = "GetEmbedding" + tencentEmbeddingRequest := ConvertEmbeddingRequest(*request) + payload, err := json.Marshal(tencentEmbeddingRequest) + if err != nil { + return nil, err + } + // we have to calculate the sign here + a.Sign = GetSign(payload, a, secretId, secretKey) + return tencentEmbeddingRequest, nil + default: + a.Action = "ChatCompletions" + tencentRequest := ConvertRequest(*request) + payload, err := json.Marshal(tencentRequest) + if err != nil { + return nil, err + } + // we have to calculate the sign here + a.Sign = GetSign(payload, a, secretId, secretKey) + return tencentRequest, nil + } } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { @@ -75,7 +96,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, responseText = StreamHandler(c, resp) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { - err, usage = Handler(c, resp) + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } } return } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index e8631e5f..7997bfd6 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "hunyuan-standard-256K", "hunyuan-pro", "hunyuan-vision", + "hunyuan-embedding", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 827c8a46..05b49f89 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/render" "io" "net/http" @@ -44,8 +45,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { } } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + InputList: request.ParseInput(), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var tencentResponseP EmbeddingResponseP + err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + tencentResponse := tencentResponseP.Response + if tencentResponse.Error.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: tencentResponse.Error.Message, + Code: tencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + requestModel := c.GetString(ctxkey.RequestModel) + fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) + fullTextResponse.Model = requestModel + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), + Model: "hunyuan-embedding", + Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, + } + + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ + Id: response.ReqID, Object: "chat.completion", Created: helper.GetTimestamp(), Usage: model.Usage{ @@ -148,7 +209,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } TencentResponse = responseP.Response - if TencentResponse.Error.Code != 0 { + if TencentResponse.Error.Code != "" { return &model.ErrorWithStatusCode{ Error: model.Error{ Message: TencentResponse.Error.Message, @@ -195,7 +256,7 @@ func hmacSha256(s, key string) string { return string(hashed.Sum(nil)) } -func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { +func GetSign(payload []byte, adaptor *Adaptor, secId, secKey string) string { // build canonical request string host := "hunyuan.tencentcloudapi.com" httpRequestMethod := "POST" @@ -204,7 +265,6 @@ func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n", "application/json", host, strings.ToLower(adaptor.Action)) signedHeaders := "content-type;host;x-tc-action" - payload, _ := json.Marshal(req) hashedRequestPayload := sha256hex(string(payload)) canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", httpRequestMethod, diff --git a/relay/adaptor/tencent/model.go b/relay/adaptor/tencent/model.go index fb97724e..fda6c6cc 100644 --- a/relay/adaptor/tencent/model.go +++ b/relay/adaptor/tencent/model.go @@ -35,16 +35,16 @@ type ChatRequest struct { // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 // 3. 非必要不建议使用,不合理的取值会影响效果。 - TopP *float64 `json:"TopP"` + TopP *float64 `json:"TopP,omitempty"` // 说明: // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 // 3. 非必要不建议使用,不合理的取值会影响效果。 - Temperature *float64 `json:"Temperature"` + Temperature *float64 `json:"Temperature,omitempty"` } type Error struct { - Code int `json:"Code"` + Code string `json:"Code"` Message string `json:"Message"` } @@ -61,15 +61,41 @@ type ResponseChoices struct { } type ChatResponse struct { - Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 - Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 - Id string `json:"Id,omitempty"` // 会话 id - Usage Usage `json:"Usage,omitempty"` // token 数量 - Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 - Note string `json:"Note,omitempty"` // 注释 - ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 + Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 + Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 + Id string `json:"Id,omitempty"` // 会话 id + Usage Usage `json:"Usage,omitempty"` // token 数量 + Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"Note,omitempty"` // 注释 + ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 } type ChatResponseP struct { Response ChatResponse `json:"Response,omitempty"` } + +type EmbeddingRequest struct { + InputList []string `json:"InputList"` +} + +type EmbeddingData struct { + Embedding []float64 `json:"Embedding"` + Index int `json:"Index"` + Object string `json:"Object"` +} + +type EmbeddingUsage struct { + PromptTokens int `json:"PromptTokens"` + TotalTokens int `json:"TotalTokens"` +} + +type EmbeddingResponse struct { + Data []EmbeddingData `json:"Data"` + EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"` + RequestId string `json:"RequestId,omitempty"` + Error Error `json:"Error,omitempty"` +} + +type EmbeddingResponseP struct { + Response EmbeddingResponse `json:"Response,omitempty"` +}