From e0ed59bfe330b5510da681337ee8c14254e8fbc5 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 6 Jul 2024 01:32:40 +0800 Subject: [PATCH] feat: support dify (close #299) --- common/constants.go | 1 + controller/channel-test.go | 6 +- relay/channel/dify/adaptor.go | 56 +++++++++ relay/channel/dify/constants.go | 5 + relay/channel/dify/dto.go | 35 ++++++ relay/channel/dify/relay-dify.go | 154 +++++++++++++++++++++++++ relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + web/src/constants/channel.constants.js | 1 + 9 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 relay/channel/dify/adaptor.go create mode 100644 relay/channel/dify/constants.go create mode 100644 relay/channel/dify/dto.go create mode 100644 relay/channel/dify/relay-dify.go diff --git a/common/constants.go b/common/constants.go index 7454f57..386f37d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -210,6 +210,7 @@ const ( ChannelTypeCohere = 34 ChannelTypeMiniMax = 35 ChannelTypeSunoAPI = 36 + ChannelTypeDify = 37 ChannelTypeDummy // this one is only for count, do not add any channel after this diff --git a/controller/channel-test.go b/controller/channel-test.go index 1beb5e1..0113f7e 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -67,7 +67,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel } else { - testModel = adaptor.GetModelList()[0] + if len(adaptor.GetModelList()) > 0 { + testModel = adaptor.GetModelList()[0] + } else { + testModel = "gpt-3.5-turbo" + } } } else { modelMapping := *channel.ModelMapping diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go new file mode 100644 index 0000000..99f3792 --- /dev/null +++ b/relay/channel/dify/adaptor.go @@ -0,0 +1,56 @@ +package dify + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + 
"one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return requestOpenAI2Dify(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = difyStreamHandler(c, resp, info) + } else { + err, usage = difyHandler(c, resp, info) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/dify/constants.go b/relay/channel/dify/constants.go new file mode 100644 index 0000000..db3e67c --- /dev/null +++ b/relay/channel/dify/constants.go @@ -0,0 +1,5 @@ +package dify + +var ModelList []string + +var ChannelName = "dify" diff --git a/relay/channel/dify/dto.go b/relay/channel/dify/dto.go new file mode 100644 index 0000000..74ef604 --- /dev/null +++ b/relay/channel/dify/dto.go @@ -0,0 +1,35 @@ +package dify + +import "one-api/dto" + +type DifyChatRequest struct { + Inputs map[string]interface{} `json:"inputs"` + Query string 
`json:"query"` + ResponseMode string `json:"response_mode"` + User string `json:"user"` + AutoGenerateName bool `json:"auto_generate_name"` +} + +type DifyMetaData struct { + Usage dto.Usage `json:"usage"` +} + +type DifyData struct { + WorkflowId string `json:"workflow_id"` + NodeId string `json:"node_id"` +} + +type DifyChatCompletionResponse struct { + ConversationId string `json:"conversation_id"` + Answers string `json:"answer"` + CreateAt int64 `json:"created_at"` + MetaData DifyMetaData `json:"metadata"` +} + +type DifyChunkChatCompletionResponse struct { + Event string `json:"event"` + ConversationId string `json:"conversation_id"` + Answer string `json:"answer"` + Data DifyData `json:"data"` + MetaData DifyMetaData `json:"metadata"` +} diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go new file mode 100644 index 0000000..5553dfa --- /dev/null +++ b/relay/channel/dify/relay-dify.go @@ -0,0 +1,154 @@ +package dify + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest { + content := "" + for _, message := range request.Messages { + if message.Role == "system" { + content += "SYSTEM: \n" + message.StringContent() + "\n" + } else if message.Role == "assistant" { + content += "ASSISTANT: \n" + message.StringContent() + "\n" + } else { + content += "USER: \n" + message.StringContent() + "\n" + } + } + mode := "blocking" + if request.Stream { + mode = "streaming" + } + user := request.User + if user == "" { + user = "api-user" + } + return &DifyChatRequest{ + Inputs: make(map[string]interface{}), + Query: content, + ResponseMode: mode, + User: user, + AutoGenerateName: false, + } +} + +func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse { + response :=
dto.ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "dify", + } + var choice dto.ChatCompletionsStreamResponseChoice + if difyResponse.Event == "workflow_started" { + choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") + } else if difyResponse.Event == "node_started" { + choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") + } else if difyResponse.Event == "message" { + choice.Delta.SetContentString(difyResponse.Answer) + } + response.Choices = append(response.Choices, choice) + return &response +} + +func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var responseText string + usage := &dto.Usage{} + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue + } + data = strings.TrimPrefix(data, "data:") + var difyResponse DifyChunkChatCompletionResponse + err := json.Unmarshal([]byte(data), &difyResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + var openaiResponse dto.ChatCompletionsStreamResponse + if difyResponse.Event == "message_end" { + usage = &difyResponse.MetaData.Usage + break + } else if difyResponse.Event == "error" { + break + } else { + openaiResponse = *streamResponseDify2OpenAI(difyResponse) + if len(openaiResponse.Choices) != 0 { + responseText += openaiResponse.Choices[0].Delta.GetContentString() + } + } + err = service.ObjectData(c, openaiResponse) + if err != nil { + common.SysError(err.Error()) + } + } + if err := scanner.Err(); err != nil { + common.SysError("error reading stream: " + err.Error()) + } + service.Done(c) + err := resp.Body.Close() + if err != nil { + //return service.OpenAIErrorWrapper(err, 
"close_response_body_failed", http.StatusInternalServerError), nil + common.SysError("close_response_body_failed: " + err.Error()) + } + if usage.TotalTokens == 0 { + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + return nil, usage +} + +func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var difyResponse DifyChatCompletionResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &difyResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + fullTextResponse := dto.OpenAITextResponse{ + Id: difyResponse.ConversationId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Usage: difyResponse.MetaData.Usage, + } + content, _ := json.Marshal(difyResponse.Answers) + choice := dto.OpenAITextResponseChoice{ + Index: 0, + Message: dto.Message{ + Role: "assistant", + Content: content, + }, + FinishReason: "stop", + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &difyResponse.MetaData.Usage +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 
943c407..15ba541 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -20,6 +20,7 @@ const ( APITypePerplexity APITypeAws APITypeCohere + APITypeDify APITypeDummy // this one is only for count, do not add any channel after this ) @@ -57,6 +58,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeAws case common.ChannelTypeCohere: apiType = APITypeCohere + case common.ChannelTypeDify: + apiType = APITypeDify } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index bfa13f4..36edef3 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/relay/channel/baidu" "one-api/relay/channel/claude" "one-api/relay/channel/cohere" + "one-api/relay/channel/dify" "one-api/relay/channel/gemini" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" @@ -53,6 +54,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &aws.Adaptor{} case constant.APITypeCohere: return &cohere.Adaptor{} + case constant.APITypeDify: + return &dify.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index e67dbc6..9383f21 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -104,6 +104,7 @@ export const CHANNEL_OPTIONS = [ { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' }, { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' }, + { key: 37, text: 'Dify', value: 37, color: 'green', label: 'Dify' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' }, { key: 22,