new-api/relay/channel/tencent/relay-tencent.go
2024-07-08 23:42:16 +08:00

242 lines
7.0 KiB
Go

package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strconv"
"strings"
"time"
)
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, &TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
}
if request.TopP != 0 {
req.TopP = &request.TopP
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: dto.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &relaycommon.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if tencentSb.Response.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func sha256hex(s string) string {
b := sha256.Sum256([]byte(s))
return hex.EncodeToString(b[:])
}
func hmacSha256(s, key string) string {
hashed := hmac.New(sha256.New, []byte(key))
hashed.Write([]byte(s))
return string(hashed.Sum(nil))
}
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string
host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST"
canonicalURI := "/"
canonicalQueryString := ""
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
"application/json", host, strings.ToLower(adaptor.Action))
signedHeaders := "content-type;host;x-tc-action"
payload, _ := json.Marshal(req)
hashedRequestPayload := sha256hex(string(payload))
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
httpRequestMethod,
canonicalURI,
canonicalQueryString,
canonicalHeaders,
signedHeaders,
hashedRequestPayload)
// build string to sign
algorithm := "TC3-HMAC-SHA256"
requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
t := time.Unix(timestamp, 0).UTC()
// must be the format 2006-01-02, ref to package time for more info
date := t.Format("2006-01-02")
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
hashedCanonicalRequest := sha256hex(canonicalRequest)
string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
algorithm,
requestTimestamp,
credentialScope,
hashedCanonicalRequest)
// sign string
secretDate := hmacSha256(date, "TC3"+secKey)
secretService := hmacSha256("hunyuan", secretDate)
secretKey := hmacSha256("tc3_request", secretService)
signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
// build authorization
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
algorithm,
secId,
credentialScope,
signedHeaders,
signature)
return authorization
}