mirror of https://github.com/songquanpeng/one-api.git (synced 2025-11-04 15:53:42 +08:00)

feat: support hunyuan-embedding (#2035)

* feat: support hunyuan-embedding
* chore: improve implementation

Co-authored-by: LUO Feng <luofeng@flowpp.com>
Co-authored-by: JustSong <quanpengsong@gmail.com>
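For context, a minimal client-side sketch of what this commit enables: sending an OpenAI-style embeddings request through a one-api deployment with the model set to hunyuan-embedding. The base URL, token variable, and /v1/embeddings path are assumptions about a typical deployment rather than anything shown in this diff.

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
        "os"
    )

    func main() {
        // Assumed deployment details: adjust the base URL and token for your gateway.
        baseURL := "http://localhost:3000"
        token := os.Getenv("ONE_API_TOKEN")

        body, _ := json.Marshal(map[string]any{
            "model": "hunyuan-embedding",
            "input": []string{"hello", "world"},
        })

        req, err := http.NewRequest(http.MethodPost, baseURL+"/v1/embeddings", bytes.NewReader(body))
        if err != nil {
            panic(err)
        }
        req.Header.Set("Content-Type", "application/json")
        req.Header.Set("Authorization", "Bearer "+token)

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        var out map[string]any
        if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
            panic(err)
        }
        fmt.Println(out["model"], out["usage"])
    }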
@@ -2,16 +2,19 @@ package tencent
 
 import (
 	"errors"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+
 	"github.com/gin-gonic/gin"
+
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/relay/adaptor"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
-	"strconv"
-	"strings"
+	"github.com/songquanpeng/one-api/relay/relaymode"
 )
 
 // https://cloud.tencent.com/document/api/1729/101837
@@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if err != nil {
 		return nil, err
 	}
-	tencentRequest := ConvertRequest(*request)
+	var convertedRequest any
+	switch relayMode {
+	case relaymode.Embeddings:
+		a.Action = "GetEmbedding"
+		convertedRequest = ConvertEmbeddingRequest(*request)
+	default:
+		a.Action = "ChatCompletions"
+		convertedRequest = ConvertRequest(*request)
+	}
 	// we have to calculate the sign here
-	a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
-	return tencentRequest, nil
+	a.Sign = GetSign(convertedRequest, a, secretId, secretKey)
+	return convertedRequest, nil
 }
 
 func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
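To make the new embeddings branch above concrete, here is a standalone sketch of the wire format the adaptor now sends for relaymode.Embeddings: the OpenAI-style input, once normalized to a string slice (which is what request.ParseInput() appears to provide), is wrapped in Tencent's InputList field. The struct is a local copy of the EmbeddingRequest type added further down in this commit.

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Local copy of the adaptor's EmbeddingRequest wire type (from this commit).
    type EmbeddingRequest struct {
        InputList []string `json:"InputList"`
    }

    func main() {
        // An OpenAI-style "input" already normalized to a slice of strings.
        input := []string{"hello", "world"}

        // The GetEmbedding body the adaptor would send to Tencent.
        body, err := json.Marshal(EmbeddingRequest{InputList: input})
        if err != nil {
            panic(err)
        }
        fmt.Println(string(body)) // {"InputList":["hello","world"]}
    }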
@@ -75,8 +86,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 		err, responseText = StreamHandler(c, resp)
 		usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
 	} else {
-		err, usage = Handler(c, resp)
+		switch meta.Mode {
+		case relaymode.Embeddings:
+			err, usage = EmbeddingHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp)
+		}
 	}
 	return
 }

@@ -6,4 +6,5 @@ var ModelList = []string{
 	"hunyuan-standard-256K",
 	"hunyuan-pro",
 	"hunyuan-vision",
+	"hunyuan-embedding",
 }

@@ -8,7 +8,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
 	"strconv"
@@ -16,11 +15,14 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/conv"
+	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/common/render"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -44,8 +46,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 	}
 }
 
+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	return &EmbeddingRequest{
+		InputList: request.ParseInput(),
+	}
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var tencentResponseP EmbeddingResponseP
+	err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	tencentResponse := tencentResponseP.Response
+	if tencentResponse.Error.Code != "" {
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
+				Message: tencentResponse.Error.Message,
+				Code:    tencentResponse.Error.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	requestModel := c.GetString(ctxkey.RequestModel)
+	fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
+	fullTextResponse.Model = requestModel
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+	openAIEmbeddingResponse := openai.EmbeddingResponse{
+		Object: "list",
+		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
+		Model:  "hunyuan-embedding",
+		Usage:  model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
+	}
+
+	for _, item := range response.Data {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+			Object:    item.Object,
+			Index:     item.Index,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
 	fullTextResponse := openai.TextResponse{
 		Id:      response.ReqID,
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
 		Usage: model.Usage{
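For reference, a rough sketch of the OpenAI-style body that EmbeddingHandler writes back after embeddingResponseTencent2OpenAI runs. The local types below only mirror the openai package's response structs, and the lowercase JSON tags are an assumption about how that package serializes; the field values are illustrative.

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Local stand-ins for openai.EmbeddingResponse / EmbeddingResponseItem
    // (field names follow the OpenAI embeddings API; the tags are assumed).
    type embeddingItem struct {
        Object    string    `json:"object"`
        Index     int       `json:"index"`
        Embedding []float64 `json:"embedding"`
    }

    type usage struct {
        TotalTokens int `json:"total_tokens"`
    }

    type embeddingResponse struct {
        Object string          `json:"object"`
        Data   []embeddingItem `json:"data"`
        Model  string          `json:"model"`
        Usage  usage           `json:"usage"`
    }

    func main() {
        out := embeddingResponse{
            Object: "list",
            Data: []embeddingItem{
                {Object: "embedding", Index: 0, Embedding: []float64{0.01, -0.02}},
            },
            Model: "hunyuan-embedding",
            Usage: usage{TotalTokens: 3},
        }
        b, _ := json.MarshalIndent(out, "", "  ")
        fmt.Println(string(b))
    }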
@@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	TencentResponse = responseP.Response
-	if TencentResponse.Error.Code != 0 {
+	if TencentResponse.Error.Code != "" {
 		return &model.ErrorWithStatusCode{
 			Error: model.Error{
 				Message: TencentResponse.Error.Message,
@@ -195,7 +257,7 @@ func hmacSha256(s, key string) string {
 	return string(hashed.Sum(nil))
 }
 
-func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
+func GetSign(req any, adaptor *Adaptor, secId, secKey string) string {
 	// build canonical request string
 	host := "hunyuan.tencentcloudapi.com"
 	httpRequestMethod := "POST"

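Since GetSign now accepts any request shape, the signed payload is presumably the JSON-marshaled body of whichever request (chat or embedding) was just built; the hunk above only shows the signature change. The helpers below are a self-contained sketch in the style of the file's hmacSha256 and of the payload hashing that Tencent's signing scheme requires; they are not the full TC3-HMAC-SHA256 flow.

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "encoding/hex"
        "fmt"
    )

    // hmacSha256 returns the raw HMAC-SHA256 of s keyed by key,
    // mirroring the shape of the helper GetSign relies on.
    func hmacSha256(s, key string) string {
        hashed := hmac.New(sha256.New, []byte(key))
        hashed.Write([]byte(s))
        return string(hashed.Sum(nil))
    }

    // sha256hex hashes a payload (for example the marshaled request body)
    // as the canonical request construction needs.
    func sha256hex(s string) string {
        b := sha256.Sum256([]byte(s))
        return hex.EncodeToString(b[:])
    }

    func main() {
        payload := `{"InputList":["hello"]}`
        fmt.Println(sha256hex(payload))
        fmt.Println(hex.EncodeToString([]byte(hmacSha256("string-to-sign", "secret"))))
    }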
@@ -35,16 +35,16 @@ type ChatRequest struct {
 	// 1. Higher values increase the diversity of the generated text.
 	// 2. Valid range is [0.0, 1.0]; if omitted, each model's recommended value is used.
 	// 3. Not recommended unless necessary; unreasonable values can degrade the output.
-	TopP *float64 `json:"TopP"`
+	TopP *float64 `json:"TopP,omitempty"`
 	// Notes:
 	// 1. Higher values make the output more random; lower values make it more focused and deterministic.
 	// 2. Valid range is [0.0, 2.0]; if omitted, each model's recommended value is used.
 	// 3. Not recommended unless necessary; unreasonable values can degrade the output.
-	Temperature *float64 `json:"Temperature"`
+	Temperature *float64 `json:"Temperature,omitempty"`
 }
 
 type Error struct {
-	Code    int    `json:"Code"`
+	Code    string `json:"Code"`
 	Message string `json:"Message"`
 }
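The switch to omitempty on these pointer fields means that when a caller leaves TopP or Temperature unset, the keys are dropped from the payload entirely and Tencent falls back to each model's recommended value, rather than receiving an explicit null. A small sketch of the effect, using a trimmed-down copy of the struct:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Trimmed-down copy of the sampling fields on ChatRequest.
    type chatParams struct {
        TopP        *float64 `json:"TopP,omitempty"`
        Temperature *float64 `json:"Temperature,omitempty"`
    }

    func main() {
        // Unset: both fields disappear from the payload.
        unset, _ := json.Marshal(chatParams{})
        fmt.Println(string(unset)) // {}

        // Set: only the provided field is sent.
        t := 0.7
        set, _ := json.Marshal(chatParams{Temperature: &t})
        fmt.Println(string(set)) // {"Temperature":0.7}
    }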
@@ -67,9 +67,35 @@ type ChatResponse struct {
 	Usage   Usage             `json:"Usage,omitempty"`     // token counts
 	Error   Error             `json:"Error,omitempty"`     // error info; note: this field may return null, meaning no valid value is available
 	Note    string            `json:"Note,omitempty"`      // note
-	ReqID   string            `json:"Req_id,omitempty"`    // unique request ID, returned with every request; used when reporting issues with the call
+	ReqID   string            `json:"RequestId,omitempty"` // unique request ID, returned with every request; used when reporting issues with the call
 }
 
 type ChatResponseP struct {
 	Response ChatResponse `json:"Response,omitempty"`
 }
 
+type EmbeddingRequest struct {
+	InputList []string `json:"InputList"`
+}
+
+type EmbeddingData struct {
+	Embedding []float64 `json:"Embedding"`
+	Index     int       `json:"Index"`
+	Object    string    `json:"Object"`
+}
+
+type EmbeddingUsage struct {
+	PromptTokens int `json:"PromptTokens"`
+	TotalTokens  int `json:"TotalTokens"`
+}
+
+type EmbeddingResponse struct {
+	Data           []EmbeddingData `json:"Data"`
+	EmbeddingUsage EmbeddingUsage  `json:"Usage,omitempty"`
+	RequestId      string          `json:"RequestId,omitempty"`
+	Error          Error           `json:"Error,omitempty"`
+}
+
+type EmbeddingResponseP struct {
+	Response EmbeddingResponse `json:"Response,omitempty"`
+}

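For reference, a standalone sketch of how the wrapped GetEmbedding response decodes with the types defined above; the sample JSON body is illustrative, but the field names and the top-level Response nesting follow the EmbeddingResponseP and EmbeddingResponse definitions in this diff.

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Local copies of the response types added in this commit.
    type embeddingData struct {
        Embedding []float64 `json:"Embedding"`
        Index     int       `json:"Index"`
        Object    string    `json:"Object"`
    }

    type embeddingUsage struct {
        PromptTokens int `json:"PromptTokens"`
        TotalTokens  int `json:"TotalTokens"`
    }

    type embeddingResponse struct {
        Data           []embeddingData `json:"Data"`
        EmbeddingUsage embeddingUsage  `json:"Usage,omitempty"`
        RequestId      string          `json:"RequestId,omitempty"`
    }

    type embeddingResponseP struct {
        Response embeddingResponse `json:"Response,omitempty"`
    }

    func main() {
        // Illustrative body: everything sits under a top-level "Response" key.
        body := `{"Response":{"RequestId":"req-123","Data":[{"Object":"embedding","Index":0,"Embedding":[0.1,0.2]}],"Usage":{"PromptTokens":3,"TotalTokens":3}}}`

        var wrapped embeddingResponseP
        if err := json.Unmarshal([]byte(body), &wrapped); err != nil {
            panic(err)
        }
        fmt.Println(wrapped.Response.RequestId, wrapped.Response.EmbeddingUsage.TotalTokens, len(wrapped.Response.Data))
    }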