diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a5..edfccf4a 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( AvailableModels = "available_models" KeyRequestBody = "key_request_body" SystemPrompt = "system_prompt" + LoraId = "lora_id" ) diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 3984ba5a..164799dc 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "time" @@ -28,7 +29,7 @@ import ( // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html -func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { +func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string, xunfeiPatchId string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { messages = append(messages, Message{ @@ -38,6 +39,9 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId + if xunfeiPatchId != "" { + xunfeiRequest.Header.PatchId = []string{xunfeiPatchId} + } xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N @@ -93,7 +97,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Id: response.Header.Sid, Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -102,7 +106,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { return &fullTextResponse } -func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { +func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse, usage *model.Usage) *openai.ChatCompletionsStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{ { @@ -122,6 +126,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl Created: helper.GetTimestamp(), Model: "SparkDesk", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + Usage: usage, } return &response } @@ -153,11 +158,12 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { } func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } + common.SetEventStreamHeaders(c) var usage model.Usage c.Stream(func(w io.Writer) bool { @@ -166,7 +172,23 @@ func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpe usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) + if xunfeiResponse.Header.Code != 0 { + errMessage := fmt.Sprintf("Xunfei request failed with Sid: %s code: %d, msg: %s", xunfeiResponse.Header.Sid, xunfeiResponse.Header.Code, xunfeiResponse.Header.Message) + logger.SysError(errMessage) + mStr, err := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "message": errMessage, + "code": xunfeiResponse.Header.Code, + }, + }) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(mStr), Event: "error"}) + return false // 停止流式响应 + } + response := streamResponseXunfei2OpenAI(&xunfeiResponse, &usage) jsonResponse, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) @@ -183,8 +205,8 @@ func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpe } func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { - domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } @@ -205,6 +227,10 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq case stop = <-stopChan: } } + if xunfeiResponse.Header.Code != 0 { + return openai.ErrorWrapper(errors.New("xunfei response error: sid: "+xunfeiResponse.Header.Sid), strconv.Itoa(xunfeiResponse.Header.Code), http.StatusInternalServerError), nil + + } if len(xunfeiResponse.Payload.Choices.Text) == 0 { return openai.ErrorWrapper(errors.New("xunfei empty response detected"), "xunfei_empty_response_detected", http.StatusInternalServerError), nil } @@ -220,7 +246,7 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq return nil, &usage } -func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { +func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId, patchId string) (chan ChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } @@ -228,7 +254,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, if err != nil || resp.StatusCode != 101 { return nil, nil, err } - data := requestOpenAI2Xunfei(textRequest, appId, domain) + data := requestOpenAI2Xunfei(textRequest, appId, domain, patchId) err = conn.WriteJSON(data) if err != nil { return nil, nil, err @@ -300,7 +326,7 @@ func apiVersion2domain(apiVersion string) string { return "general" + apiVersion } -func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { +func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string, modelName string) (string, string) { var authUrl string domain := apiVersion2domain(apiVersion) switch apiVersion { @@ -310,6 +336,13 @@ func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (strin case "v3.5-32K": authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) break + case "maas": + domain = modelName + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://maas-api.cn-huabei-1.xf-yun.com/v1.1/chat"), apiKey, apiSecret) + case "xingchen": + domain = modelName + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://xingcheng-api.cn-huabei-1.xf-yun.com/v1.1/chat"), apiKey, apiSecret) + default: authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) } diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go index c9fb1bb8..e22b22e4 100644 --- a/relay/adaptor/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -15,7 +15,8 @@ type Functions struct { type ChatRequest struct { Header struct { - AppId string `json:"app_id"` + AppId string `json:"app_id"` + PatchId []string `json:"patch_id,omitempty"` } `json:"header"` Parameter struct { Chat struct { diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index bcbe1045..461c244c 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -31,6 +31,9 @@ type Meta struct { RequestURLPath string PromptTokens int // only for DoResponse SystemPrompt string + + //Lora_id + LoraId string } func GetByContext(c *gin.Context) *Meta { @@ -48,6 +51,7 @@ func GetByContext(c *gin.Context) *Meta { APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), SystemPrompt: c.GetString(ctxkey.SystemPrompt), + LoraId: c.Request.Header.Get(ctxkey.LoraId), } cfg, ok := c.Get(ctxkey.Config) if ok {