From 6c5b3b51b079a388ce8218905895e0c2c6bb6f7b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 5 Jul 2024 20:00:52 +0800 Subject: [PATCH] fix: try to fix tencent hunyuan #336 --- controller/task.go | 2 +- relay/channel/tencent/adaptor.go | 20 ++- relay/channel/tencent/dto.go | 95 ++++++----- relay/channel/tencent/relay-tencent.go | 211 +++++++++++++------------ service/sse.go | 28 +++- 5 files changed, 200 insertions(+), 156 deletions(-) diff --git a/controller/task.go b/controller/task.go index fce9e7f..e94abaa 100644 --- a/controller/task.go +++ b/controller/task.go @@ -24,7 +24,7 @@ func UpdateTaskBulk() { //imageModel := "midjourney" for { time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") + common.SysLog(" 任务进度轮询开始") ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 470ec14..33eda3f 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -6,18 +6,26 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "strconv" "strings" ) type Adaptor struct { - Sign string + Sign string + Action string + Version string + Timestamp int64 } func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + a.Action = "ChatCompletions" + a.Version = "2023-09-01" + a.Timestamp = common.GetTimestamp() } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -27,7 +35,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) req.Header.Set("Authorization", a.Sign) - req.Header.Set("X-TC-Action", info.UpstreamModelName) + req.Header.Set("X-TC-Action", a.Action) + req.Header.Set("X-TC-Version", a.Version) + req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) return nil } @@ -37,15 +47,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := parseTencentConfig(apiKey) + _, secretId, secretKey, err := parseTencentConfig(apiKey) if err != nil { return nil, err } tencentRequest := requestOpenAI2Tencent(*request) - tencentRequest.AppId = appId - tencentRequest.SecretId = secretId // we have to calculate the sign here - a.Sign = getTencentSign(*tencentRequest, secretKey) + a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey) return tencentRequest, nil } diff --git a/relay/channel/tencent/dto.go b/relay/channel/tencent/dto.go index d031432..395ccbb 100644 --- a/relay/channel/tencent/dto.go +++ b/relay/channel/tencent/dto.go @@ -1,62 +1,71 @@ package tencent -import "one-api/dto" - type TencentMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"Role"` + Content string `json:"Content"` } type TencentChatRequest struct { - AppId int64 `json:"app_id"` // 腾讯云账号的 APPID - SecretId string `json:"secret_id"` // 官网 SecretId - // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 - // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 - Timestamp int64 `json:"timestamp"` - // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, - // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 - Expired int64 `json:"expired"` - QueryID string `json:"query_id"` //请求 Id,用于问题排查 - // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 - // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 - // 建议该参数和 top_p 只设置1个,不要同时更改 top_p - Temperature float64 `json:"temperature"` - // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 - // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 - // 建议该参数和 temperature 只设置1个,不要同时更改 - TopP float64 `json:"top_p"` - // Stream 0:同步,1:流式 (默认,协议:SSE) - // 同步请求超时:60s,如果内容较长建议使用流式 - Stream int `json:"stream"` - // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 - // 输入 content 总数最大支持 3000 token。 - Messages []TencentMessage `json:"messages"` - Model string `json:"model"` // 模型名称 + // 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。 + // 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。 + // + // 注意: + // 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。 + Model *string `json:"Model"` + // 聊天上下文信息。 + // 说明: + // 1. 长度最多为 40,按对话时间从旧到新在数组中排列。 + // 2. Message.Role 可选值:system、user、assistant。 + // 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。 + // 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。 + Messages []*TencentMessage `json:"Messages"` + // 流式调用开关。 + // 说明: + // 1. 未传值时默认为非流式调用(false)。 + // 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。 + // 3. 非流式调用时: + // 调用方式与普通 HTTP 请求无异。 + // 接口响应耗时较长,**如需更低时延建议设置为 true**。 + // 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。 + // + // 注意: + // 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。 + Stream *bool `json:"Stream"` + // 说明: + // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 + // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 + // 3. 非必要不建议使用,不合理的取值会影响效果。 + TopP *float64 `json:"TopP"` + // 说明: + // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 + // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 + // 3. 非必要不建议使用,不合理的取值会影响效果。 + Temperature *float64 `json:"Temperature"` } type TencentError struct { - Code int `json:"code"` - Message string `json:"message"` + Code int `json:"Code"` + Message string `json:"Message"` } type TencentUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"PromptTokens"` + CompletionTokens int `json:"CompletionTokens"` + TotalTokens int `json:"TotalTokens"` } type TencentResponseChoices struct { - FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 - Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 - Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 + FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 } type TencentChatResponse struct { - Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 - Created string `json:"created,omitempty"` // unix 时间戳的字符串 - Id string `json:"id,omitempty"` // 会话 id - Usage dto.Usage `json:"usage,omitempty"` // token 数量 - Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 - Note string `json:"note,omitempty"` // 注释 - ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 + Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果 + Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 + Id string `json:"Id,omitempty"` // 会话 id + Usage TencentUsage `json:"Usage,omitempty"` // token 数量 + Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"Note,omitempty"` // 注释 + ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 } diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 87e0a2f..9858011 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -3,8 +3,8 @@ package tencent import ( "bufio" "crypto/hmac" - "crypto/sha1" - "encoding/base64" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -15,46 +15,28 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" - "sort" "strconv" "strings" + "time" ) // https://cloud.tencent.com/document/product/1729/97732 func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest { - messages := make([]TencentMessage, 0, len(request.Messages)) + messages := make([]*TencentMessage, 0, len(request.Messages)) for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, TencentMessage{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, TencentMessage{ - Role: "assistant", - Content: "Okay", - }) - continue - } - messages = append(messages, TencentMessage{ + messages = append(messages, &TencentMessage{ Content: message.StringContent(), Role: message.Role, }) } - stream := 0 - if request.Stream { - stream = 1 - } return &TencentChatRequest{ - Timestamp: common.GetTimestamp(), - Expired: common.GetTimestamp() + 24*60*60, - QueryID: common.GetUUID(), - Temperature: request.Temperature, - TopP: request.TopP, - Stream: stream, + Temperature: &request.Temperature, + TopP: &request.TopP, + Stream: &request.Stream, Messages: messages, - Model: request.Model, + Model: &request.Model, } } @@ -62,7 +44,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon fullTextResponse := dto.OpenAITextResponse{ Object: "chat.completion", Created: common.GetTimestamp(), - Usage: response.Usage, + Usage: dto.Usage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, } if len(response.Choices) > 0 { content, _ := json.Marshal(response.Choices[0].Messages.Content) @@ -99,64 +85,46 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { var responseText string scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var TencentResponse TencentChatResponse - err := json.Unmarshal([]byte(data), &TencentResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := streamResponseTencent2OpenAI(&TencentResponse) - if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.GetContentString() - } - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue } - }) + data = strings.TrimPrefix(data, "data:") + + var tencentResponse TencentChatResponse + err := json.Unmarshal([]byte(data), &tencentResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseTencent2OpenAI(&tencentResponse) + if len(response.Choices) != 0 { + responseText += response.Choices[0].Delta.GetContentString() + } + + err = service.ObjectData(c, response) + if err != nil { + common.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + common.SysError("error reading stream: " + err.Error()) + } + + service.Done(c) + err := resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + return nil, responseText } @@ -206,29 +174,62 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey return } -func getTencentSign(req TencentChatRequest, secretKey string) string { - params := make([]string, 0) - params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) - params = append(params, "secret_id="+req.SecretId) - params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) - params = append(params, "query_id="+req.QueryID) - params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) - params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) - params = append(params, "stream="+strconv.Itoa(req.Stream)) - params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) - - var messageStr string - for _, msg := range req.Messages { - messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) - } - messageStr = strings.TrimSuffix(messageStr, ",") - params = append(params, "messages=["+messageStr+"]") - - sort.Sort(sort.StringSlice(params)) - url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") - mac := hmac.New(sha1.New, []byte(secretKey)) - signURL := url - mac.Write([]byte(signURL)) - sign := mac.Sum([]byte(nil)) - return base64.StdEncoding.EncodeToString(sign) +func sha256hex(s string) string { + b := sha256.Sum256([]byte(s)) + return hex.EncodeToString(b[:]) +} + +func hmacSha256(s, key string) string { + hashed := hmac.New(sha256.New, []byte(key)) + hashed.Write([]byte(s)) + return string(hashed.Sum(nil)) +} + +func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string { + // build canonical request string + host := "hunyuan.tencentcloudapi.com" + httpRequestMethod := "POST" + canonicalURI := "/" + canonicalQueryString := "" + canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n", + "application/json", host, strings.ToLower(adaptor.Action)) + signedHeaders := "content-type;host;x-tc-action" + payload, _ := json.Marshal(req) + hashedRequestPayload := sha256hex(string(payload)) + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + httpRequestMethod, + canonicalURI, + canonicalQueryString, + canonicalHeaders, + signedHeaders, + hashedRequestPayload) + // build string to sign + algorithm := "TC3-HMAC-SHA256" + requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10) + timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64) + t := time.Unix(timestamp, 0).UTC() + // must be the format 2006-01-02, ref to package time for more info + date := t.Format("2006-01-02") + credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan") + hashedCanonicalRequest := sha256hex(canonicalRequest) + string2sign := fmt.Sprintf("%s\n%s\n%s\n%s", + algorithm, + requestTimestamp, + credentialScope, + hashedCanonicalRequest) + + // sign string + secretDate := hmacSha256(date, "TC3"+secKey) + secretService := hmacSha256("hunyuan", secretDate) + secretKey := hmacSha256("tc3_request", secretService) + signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey))) + + // build authorization + authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", + algorithm, + secId, + credentialScope, + signedHeaders, + signature) + return authorization } diff --git a/service/sse.go b/service/sse.go index 4e86bad..2d531a4 100644 --- a/service/sse.go +++ b/service/sse.go @@ -1,6 +1,12 @@ package service -import "github.com/gin-gonic/gin" +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "one-api/common" + "strings" +) func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Content-Type", "text/event-stream") @@ -9,3 +15,23 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") } + +func StringData(c *gin.Context, str string) { + str = strings.TrimPrefix(str, "data: ") + str = strings.TrimSuffix(str, "\r") + c.Render(-1, common.CustomEvent{Data: "data: " + str}) + c.Writer.Flush() +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +}