From ff24d0fab87c8adfb4cee8e26a5d92b0dbf006c4 Mon Sep 17 00:00:00 2001 From: ybyang Date: Mon, 25 Nov 2024 14:27:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81LoraID=20=EF=BC=8C=20?= =?UTF-8?q?=E9=80=9A=E8=BF=87OpenAI=20Client=E8=AE=BE=E7=BD=AE=20ExtraHead?= =?UTF-8?q?er=20=E5=B9=B6=E5=9C=A8=E8=AE=AF=E9=A3=9E=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=B8=AD=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/ctxkey/key.go | 1 + relay/adaptor/xunfei/main.go | 11 ++++++----- relay/adaptor/xunfei/model.go | 3 ++- relay/meta/relay_meta.go | 4 ++++ 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a5..edfccf4a 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( AvailableModels = "available_models" KeyRequestBody = "key_request_body" SystemPrompt = "system_prompt" + LoraId = "lora_id" ) diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 842f60da..516bea5b 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -28,7 +28,7 @@ import ( // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html -func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest { +func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string, xunfeiPatchId string) *ChatRequest { messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { messages = append(messages, Message{ @@ -38,6 +38,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string } xunfeiRequest := ChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId + xunfeiRequest.Header.PatchId = xunfeiPatchId xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N @@ -154,7 +155,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } @@ -184,7 +185,7 @@ func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpe func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret, textRequest.Model) - dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId, meta.LoraId) if err != nil { return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil } @@ -220,7 +221,7 @@ func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIReq return nil, &usage } -func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) { +func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, appId, patchId string) (chan ChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } @@ -228,7 +229,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, if err != nil || resp.StatusCode != 101 { return nil, nil, err } - data := requestOpenAI2Xunfei(textRequest, appId, domain) + data := requestOpenAI2Xunfei(textRequest, appId, domain, patchId) err = conn.WriteJSON(data) if err != nil { return nil, nil, err diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go index c9fb1bb8..3741a6e5 100644 --- a/relay/adaptor/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -15,7 +15,8 @@ type Functions struct { type ChatRequest struct { Header struct { - AppId string `json:"app_id"` + AppId string `json:"app_id"` + PatchId string `json:"patch_id,omitempty"` } `json:"header"` Parameter struct { Chat struct { diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index bcbe1045..461c244c 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -31,6 +31,9 @@ type Meta struct { RequestURLPath string PromptTokens int // only for DoResponse SystemPrompt string + + //Lora_id + LoraId string } func GetByContext(c *gin.Context) *Meta { @@ -48,6 +51,7 @@ func GetByContext(c *gin.Context) *Meta { APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), SystemPrompt: c.GetString(ctxkey.SystemPrompt), + LoraId: c.Request.Header.Get(ctxkey.LoraId), } cfg, ok := c.Get(ctxkey.Config) if ok {