package xunfei import ( "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" "net/http" "net/url" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" "time" ) // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { messages := make([]XunfeiMessage, 0, len(request.Messages)) shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5") for _, message := range request.Messages { if message.Role == "system" && shouldCovertSystemMessage { messages = append(messages, XunfeiMessage{ Role: "user", Content: message.StringContent(), }) messages = append(messages, XunfeiMessage{ Role: "assistant", Content: "Okay", }) } else { messages = append(messages, XunfeiMessage{ Role: message.Role, Content: message.StringContent(), }) } } xunfeiRequest := XunfeiChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages return &xunfeiRequest } func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ { Content: "", }, } } content, _ := json.Marshal(response.Payload.Choices.Text[0].Content) choice := dto.OpenAITextResponseChoice{ Index: 0, Message: dto.Message{ Role: "assistant", Content: content, }, FinishReason: relaycommon.StopFinishReason, } fullTextResponse := dto.OpenAITextResponse{ Object: "chat.completion", Created: common.GetTimestamp(), Choices: []dto.OpenAITextResponseChoice{choice}, Usage: response.Payload.Usage.Text, } return &fullTextResponse } func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ { Content: "", }, } } var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content if xunfeiResponse.Payload.Choices.Status == 2 { choice.FinishReason = &relaycommon.StopFinishReason } response := dto.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "SparkDesk", Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { HmacWithShaToBase64 := func(algorithm, data, key string) string { mac := hmac.New(sha256.New, []byte(key)) mac.Write([]byte(data)) encodeData := mac.Sum(nil) return base64.StdEncoding.EncodeToString(encodeData) } ul, err := url.Parse(hostUrl) if err != nil { fmt.Println(err) } date := time.Now().UTC().Format(time.RFC1123) signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} sign := strings.Join(signString, "\n") sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha) authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) v := url.Values{} v.Add("host", ul.Host) v.Add("date", date) v.Add("authorization", authorization) callUrl := hostUrl + "?" + v.Encode() return callUrl } func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } service.SetEventStreamHeaders(c) var usage dto.Usage c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) return nil, &usage } func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } var usage dto.Usage var content string var xunfeiResponse XunfeiChatResponse stop := false for !stop { select { case xunfeiResponse = <-dataChan: if len(xunfeiResponse.Payload.Choices.Text) == 0 { continue } content += xunfeiResponse.Payload.Choices.Text[0].Content usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens case stop = <-stopChan: } } if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ { Content: "", }, } } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") _, _ = c.Writer.Write(jsonResponse) return nil, &usage } func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { return nil, nil, err } data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { return nil, nil, err } dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { for { _, msg, err := conn.ReadMessage() if err != nil { common.SysError("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { common.SysError("error closing websocket connection: " + err.Error()) } break } } stopChan <- true }() return dataChan, stopChan, nil } func apiVersion2domain(apiVersion string) string { switch apiVersion { case "v1.1": return "general" case "v2.1": return "generalv2" case "v3.1": return "generalv3" case "v3.5": return "generalv3.5" } return "general" + apiVersion } func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) { apiVersion := getAPIVersion(c, modelName) domain := apiVersion2domain(apiVersion) authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) return domain, authUrl } func getAPIVersion(c *gin.Context, modelName string) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion != "" { return apiVersion } parts := strings.Split(modelName, "-") if len(parts) == 2 { apiVersion = parts[1] return apiVersion } apiVersion = c.GetString("api_version") if apiVersion != "" { return apiVersion } apiVersion = "v1.1" common.SysLog("api_version not found, using default: " + apiVersion) return apiVersion }