Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-09-17 09:16:36 +08:00).
Commit: feat: support hunyuan-embedding
This commit (bccdcca7cb) has parent 3915ce9814.
@ -1,6 +1,7 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@ -8,6 +9,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -52,10 +54,29 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := ConvertRequest(*request)
|
||||
// we have to calculate the sign here
|
||||
a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
|
||||
return tencentRequest, nil
|
||||
|
||||
switch relayMode {
|
||||
case relaymode.Embeddings:
|
||||
a.Action = "GetEmbedding"
|
||||
tencentEmbeddingRequest := ConvertEmbeddingRequest(*request)
|
||||
payload, err := json.Marshal(tencentEmbeddingRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// we have to calculate the sign here
|
||||
a.Sign = GetSign(payload, a, secretId, secretKey)
|
||||
return tencentEmbeddingRequest, nil
|
||||
default:
|
||||
a.Action = "ChatCompletions"
|
||||
tencentRequest := ConvertRequest(*request)
|
||||
payload, err := json.Marshal(tencentRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// we have to calculate the sign here
|
||||
a.Sign = GetSign(payload, a, secretId, secretKey)
|
||||
return tencentRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
@ -75,7 +96,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
} else {
|
||||
err, usage = Handler(c, resp)
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -6,4 +6,5 @@ var ModelList = []string{
|
||||
"hunyuan-standard-256K",
|
||||
"hunyuan-pro",
|
||||
"hunyuan-vision",
|
||||
"hunyuan-embedding",
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/render"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -44,8 +45,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
InputList: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
// EmbeddingHandler reads a Tencent Hunyuan GetEmbedding HTTP response,
// converts it to the OpenAI embedding response format, and writes the
// result to the client. It returns a relay error (or nil) plus the token
// usage extracted from the upstream response.
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	// Tencent Cloud nests the actual payload under a top-level "Response" key.
	var tencentResponseP EmbeddingResponseP
	err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	tencentResponse := tencentResponseP.Response
	// A non-empty error code signals an upstream failure; relay the upstream
	// code and message with the original HTTP status.
	if tencentResponse.Error.Code != "" {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: tencentResponse.Error.Message,
				Code:    tencentResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	// Report the model name the client originally requested rather than the
	// upstream/actual model name.
	requestModel := c.GetString(ctxkey.RequestModel)
	fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
	fullTextResponse.Model = requestModel
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// NOTE(review): the write error is captured but never acted on; once the
	// status line has been written there is little recovery possible, but
	// consider logging it.
	_, err = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
|
||||
|
||||
func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "hunyuan-embedding",
|
||||
Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.ReqID,
|
||||
Object: "chat.completion",
|
||||
Created: helper.GetTimestamp(),
|
||||
Usage: model.Usage{
|
||||
@ -148,7 +209,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
TencentResponse = responseP.Response
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
if TencentResponse.Error.Code != "" {
|
||||
return &model.ErrorWithStatusCode{
|
||||
Error: model.Error{
|
||||
Message: TencentResponse.Error.Message,
|
||||
@ -195,7 +256,7 @@ func hmacSha256(s, key string) string {
|
||||
return string(hashed.Sum(nil))
|
||||
}
|
||||
|
||||
func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
|
||||
func GetSign(payload []byte, adaptor *Adaptor, secId, secKey string) string {
|
||||
// build canonical request string
|
||||
host := "hunyuan.tencentcloudapi.com"
|
||||
httpRequestMethod := "POST"
|
||||
@ -204,7 +265,6 @@ func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
|
||||
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
|
||||
"application/json", host, strings.ToLower(adaptor.Action))
|
||||
signedHeaders := "content-type;host;x-tc-action"
|
||||
payload, _ := json.Marshal(req)
|
||||
hashedRequestPayload := sha256hex(string(payload))
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
httpRequestMethod,
|
||||
|
@ -35,16 +35,16 @@ type ChatRequest struct {
|
||||
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
|
||||
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
TopP *float64 `json:"TopP"`
|
||||
TopP *float64 `json:"TopP,omitempty"`
|
||||
// 说明:
|
||||
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
|
||||
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
|
||||
// 3. 非必要不建议使用,不合理的取值会影响效果。
|
||||
Temperature *float64 `json:"Temperature"`
|
||||
Temperature *float64 `json:"Temperature,omitempty"`
|
||||
}
|
||||
|
||||
// Error is the error object embedded in Tencent Hunyuan API responses.
type Error struct {
	// Code is the Tencent error code; empty on success. (The Hunyuan API
	// returns string codes, hence string rather than int.)
	Code string `json:"Code"`
	Message string `json:"Message"`
}
|
||||
|
||||
@ -61,15 +61,41 @@ type ResponseChoices struct {
|
||||
}
|
||||
|
||||
// ChatResponse is the inner payload of a Tencent Hunyuan ChatCompletions response.
type ChatResponse struct {
	Choices []ResponseChoices `json:"Choices,omitempty"` // results
	Created int64             `json:"Created,omitempty"` // unix timestamp
	Id      string            `json:"Id,omitempty"`      // conversation id
	Usage   Usage             `json:"Usage,omitempty"`   // token counts
	Error   Error             `json:"Error,omitempty"`   // error info; note: upstream may return null
	Note    string            `json:"Note,omitempty"`    // annotation
	ReqID   string            `json:"RequestId,omitempty"` // unique request id, returned with every request; use when reporting issues
}
|
||||
|
||||
// ChatResponseP wraps ChatResponse under the top-level "Response" key that
// Tencent Cloud API responses use.
type ChatResponseP struct {
	Response ChatResponse `json:"Response,omitempty"`
}
|
||||
|
||||
// EmbeddingRequest is the body of a Tencent Hunyuan GetEmbedding call.
type EmbeddingRequest struct {
	// InputList holds the texts to embed.
	InputList []string `json:"InputList"`
}
|
||||
|
||||
// EmbeddingData is a single embedding vector in a Tencent GetEmbedding response.
type EmbeddingData struct {
	Embedding []float64 `json:"Embedding"`
	Index int `json:"Index"`
	Object string `json:"Object"`
}
|
||||
|
||||
// EmbeddingUsage reports token consumption for a Tencent embedding call.
type EmbeddingUsage struct {
	PromptTokens int `json:"PromptTokens"`
	TotalTokens int `json:"TotalTokens"`
}
|
||||
|
||||
// EmbeddingResponse is the inner payload of a Tencent GetEmbedding response.
type EmbeddingResponse struct {
	Data []EmbeddingData `json:"Data"`
	// EmbeddingUsage maps the upstream "Usage" object; the Go field is named
	// EmbeddingUsage to avoid clashing with the chat Usage type in this package.
	EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"`
	RequestId string `json:"RequestId,omitempty"`
	// NOTE: `omitempty` has no effect on non-pointer struct fields when
	// encoding, but this type is only ever decoded here.
	Error Error `json:"Error,omitempty"`
}
|
||||
|
||||
// EmbeddingResponseP wraps EmbeddingResponse under the top-level "Response"
// key that Tencent Cloud API responses use.
type EmbeddingResponseP struct {
	Response EmbeddingResponse `json:"Response,omitempty"`
}
|
||||
|
Loading…
Reference in New Issue
Block a user