From b0e234e8f58c83ad558fcc86b487367ea788dc29 Mon Sep 17 00:00:00 2001
From: CalciumIon <1808837298@qq.com>
Date: Mon, 8 Jul 2024 01:27:57 +0800
Subject: [PATCH] feat: support stream_options

---
 README.md                            |  2 +-
 constant/env.go                      |  3 +++
 dto/text_request.go                  |  5 ++++
 dto/text_response.go                 |  1 +
 relay/channel/ollama/adaptor.go      |  6 +++--
 relay/channel/openai/adaptor.go      | 24 ++++++++++++++---
 relay/channel/openai/relay-openai.go | 24 ++++++++++++++---
 relay/channel/perplexity/adaptor.go  |  6 +++--
 relay/channel/zhipu_4v/adaptor.go    |  8 +++---
 relay/common/relay_info.go           | 40 +++++++++++++++-------------
 relay/relay-text.go                  | 33 ++++++++---------------
 11 files changed, 97 insertions(+), 55 deletions(-)

diff --git a/README.md b/README.md
index 4eed244..4aca707 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,7 @@
 ## Configuration beyond the original One API
 - `STREAMING_TIMEOUT`: timeout in seconds for a single streaming reply, defaults to 30
 - `DIFY_DEBUG`: whether the Dify channel sends workflow and node info to the client, defaults to `true`; allowed values are `true` and `false`
-
+- `FORCE_STREAM_OPTION`: override the client's stream_options parameter and ask the upstream to return usage in stream mode; currently only supported for the `OpenAI` channel type
 ## Deployment
 ### Deployment requirements
 - Local database (default): SQLite (Docker deployments use SQLite by default and must mount the `/data` directory on the host)
diff --git a/constant/env.go b/constant/env.go
index b08bc04..96483fe 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -6,3 +6,6 @@ import (
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
 var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
+
+// ForceStreamOption overrides the client's stream_options and forces the upstream to return usage info
+var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
diff --git a/dto/text_request.go b/dto/text_request.go
index 0f696fc..e12c9b4 100644
--- a/dto/text_request.go
+++ b/dto/text_request.go
@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
 	Messages      []Message      `json:"messages,omitempty"`
 	Prompt        any            `json:"prompt,omitempty"`
 	Stream        bool           `json:"stream,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	MaxTokens     uint           `json:"max_tokens,omitempty"`
 	Temperature   float64        `json:"temperature,omitempty"`
 	TopP          float64        `json:"top_p,omitempty"`
@@ -43,6 +44,10 @@ type OpenAIFunction struct {
 	Parameters any `json:"parameters,omitempty"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
 	return int64(r.MaxTokens)
 }
diff --git a/dto/text_response.go b/dto/text_response.go
index 53c87eb..9b45368 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -106,6 +106,7 @@ type ChatCompletionsStreamResponse struct {
 
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage                                `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
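For reference, a minimal, self-contained sketch of what the new DTO fields serialize to on the wire. The types below are trimmed stand-ins for dto.GeneralOpenAIRequest and dto.StreamOptions, reduced to the fields that matter here; the model name is a placeholder.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for dto.StreamOptions.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// Trimmed stand-in for dto.GeneralOpenAIRequest.
type request struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	body, _ := json.Marshal(request{
		Model:         "gpt-3.5-turbo",
		Stream:        true,
		StreamOptions: &StreamOptions{IncludeUsage: true},
	})
	// {"model":"gpt-3.5-turbo","stream":true,"stream_options":{"include_usage":true}}
	fmt.Println(string(body))

	// A nil pointer plus omitempty keeps the field off the wire entirely.
	body, _ = json.Marshal(request{Model: "gpt-3.5-turbo", Stream: true})
	// {"model":"gpt-3.5-turbo","stream":true}
	fmt.Println(string(body))
}

Because StreamOptions is a pointer with omitempty, setting it to nil removes the field from the marshalled body, which is exactly how ConvertRequest strips it for channels that do not support it.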
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 4bf1d61..76de148 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 0c1ce25..ba4b66a 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
@@ -19,7 +20,8 @@ import (
 )
 
 type Adaptor struct {
-	ChannelType int
+	ChannelType          int
+	SupportStreamOptions bool
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -31,6 +33,7 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
+	a.SupportStreamOptions = info.SupportStreamOptions
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -78,6 +81,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// If the channel does not support stream_options, drop it from the request
+	if !a.SupportStreamOptions {
+		request.StreamOptions = nil
+	} else {
+		// If stream_options is supported and FORCE_STREAM_OPTION is enabled, override the client's value so usage is always requested
+		if constant.ForceStreamOption {
+			request.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
 	return request, nil
 }
 
@@ -89,9 +103,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
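The guard around the fallback recount in DoResponse repeats verbatim in the ollama adaptor above and in the perplexity and zhipu_4v adaptors below. A self-contained sketch of the predicate, with Usage as a local stand-in for dto.Usage and usageIsValid as a hypothetical helper name (the patch inlines the condition instead):

package main

import "fmt"

// Local stand-in for dto.Usage.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// usageIsValid mirrors the condition used in each DoResponse: trust the
// usage reported by the stream only when it is present and non-zero;
// otherwise recount tokens from the accumulated response text.
func usageIsValid(u *Usage) bool {
	return u != nil && u.TotalTokens != 0 && (u.PromptTokens+u.CompletionTokens) != 0
}

func main() {
	fmt.Println(usageIsValid(nil))      // false: no usage chunk seen, recount locally
	fmt.Println(usageIsValid(&Usage{})) // false: upstream sent zeros, recount locally
	fmt.Println(usageIsValid(&Usage{PromptTokens: 9, CompletionTokens: 2, TotalTokens: 11})) // true
}

Extracting such a helper into service would avoid keeping four copies of the condition in sync, at the cost of a slightly wider diff.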
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index b397268..7c96f69 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -18,9 +18,10 @@
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	var usage dto.Usage
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,26 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			streamItems = append(streamItems, data)
 		}
 	}
+	// count tokens
 	streamResp := "[" + strings.Join(streamItems, ",") + "]"
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
 		var streamResponses []dto.ChatCompletionsStreamResponseSimple
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
+			// batch unmarshal failed, fall back to parsing the items one by one
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			for _, item := range streamItems {
 				var streamResponse dto.ChatCompletionsStreamResponseSimple
 				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
 				if err == nil {
+					if streamResponse.Usage != nil {
+						if streamResponse.Usage.TotalTokens != 0 {
+							usage.PromptTokens += streamResponse.Usage.PromptTokens
+							usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+							usage.TotalTokens += streamResponse.Usage.TotalTokens
+						}
+					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
 						if choice.Delta.ToolCalls != nil {
@@ -89,6 +99,13 @@
 			}
 		} else {
 			for _, streamResponse := range streamResponses {
+				if streamResponse.Usage != nil {
+					if streamResponse.Usage.TotalTokens != 0 {
+						usage.PromptTokens += streamResponse.Usage.PromptTokens
+						usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+						usage.TotalTokens += streamResponse.Usage.TotalTokens
+					}
+				}
 				for _, choice := range streamResponse.Choices {
 					responseTextBuilder.WriteString(choice.Delta.GetContentString())
 					if choice.Delta.ToolCalls != nil {
@@ -107,6 +124,7 @@
 		var streamResponses []dto.CompletionsStreamResponse
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
+			// batch unmarshal failed, fall back to parsing the items one by one
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			for _, item := range streamItems {
 				var streamResponse dto.CompletionsStreamResponse
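To make the accumulation above concrete: with stream_options.include_usage set, OpenAI-style upstreams are expected to send "usage": null on ordinary chunks and one populated usage object on the final chunk (whose choices array is empty), so summing the non-zero entries yields the totals for the whole stream. A standalone sketch under that assumption, with trimmed stand-ins for the stream DTOs:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for the stream-response DTOs.
type usageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type chunk struct {
	Usage *usageInfo `json:"usage"`
}

func main() {
	items := []string{
		`{"choices":[{"delta":{"content":"Hi"}}],"usage":null}`,
		`{"choices":[],"usage":{"prompt_tokens":9,"completion_tokens":2,"total_tokens":11}}`,
	}
	var total usageInfo
	for _, item := range items {
		var c chunk
		if err := json.Unmarshal([]byte(item), &c); err != nil {
			continue // skip malformed items, as the fallback path above does
		}
		if c.Usage != nil && c.Usage.TotalTokens != 0 {
			total.PromptTokens += c.Usage.PromptTokens
			total.CompletionTokens += c.Usage.CompletionTokens
			total.TotalTokens += c.Usage.TotalTokens
		}
	}
	fmt.Printf("%+v\n", total) // {PromptTokens:9 CompletionTokens:2 TotalTokens:11}
}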
@@ -159,10 +177,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String(), toolCount
+	return nil, &usage, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index c3972d5..3c65b2d 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index eaf3087..508861f 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-		usage.CompletionTokens += toolCount * 7
+		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index fc01d17..cdad472 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -9,24 +9,25 @@ import (
 )
 
 type RelayInfo struct {
-	ChannelType       int
-	ChannelId         int
-	TokenId           int
-	UserId            int
-	Group             string
-	TokenUnlimited    bool
-	StartTime         time.Time
-	FirstResponseTime time.Time
-	ApiType           int
-	IsStream          bool
-	RelayMode         int
-	UpstreamModelName string
-	RequestURLPath    string
-	ApiVersion        string
-	PromptTokens      int
-	ApiKey            string
-	Organization      string
-	BaseUrl           string
+	ChannelType          int
+	ChannelId            int
+	TokenId              int
+	UserId               int
+	Group                string
+	TokenUnlimited       bool
+	StartTime            time.Time
+	FirstResponseTime    time.Time
+	ApiType              int
+	IsStream             bool
+	RelayMode            int
+	UpstreamModelName    string
+	RequestURLPath       string
+	ApiVersion           string
+	PromptTokens         int
+	ApiKey               string
+	Organization         string
+	BaseUrl              string
+	SupportStreamOptions bool
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
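The second relay_info.go hunk below gates the new flag on channel type. A minimal sketch of that behaviour; the channel-type constants and newRelayInfo are stand-ins for illustration (the real constants live in one-api/common, and the real logic sits inside GenRelayInfo):

package main

import "fmt"

// Stand-in channel-type constants; the real values live in one-api/common.
const (
	ChannelTypeOpenAI = iota + 1
	ChannelTypeAzure
)

type relayInfo struct {
	ChannelType          int
	SupportStreamOptions bool
}

// newRelayInfo mirrors the gating added below: only the OpenAI channel type
// opts in to stream_options for now, so every other channel has the field
// stripped again in ConvertRequest.
func newRelayInfo(channelType int) *relayInfo {
	info := &relayInfo{ChannelType: channelType}
	if channelType == ChannelTypeOpenAI {
		info.SupportStreamOptions = true
	}
	return info
}

func main() {
	fmt.Println(newRelayInfo(ChannelTypeOpenAI).SupportStreamOptions) // true
	fmt.Println(newRelayInfo(ChannelTypeAzure).SupportStreamOptions)  // false
}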
@@ -65,6 +66,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
+	if info.ChannelType == common.ChannelTypeOpenAI {
+		info.SupportStreamOptions = true
+	}
 	return info
 }
 
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 28b5d35..7d3ad2c 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
+	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
-			isModelMapped = true
+			//isModelMapped = true
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
@@ -136,27 +136,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
-	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
-		if isModelMapped {
-			jsonStr, err := json.Marshal(textRequest)
-			if err != nil {
-				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
-			}
-			requestBody = bytes.NewBuffer(jsonStr)
-		} else {
-			requestBody = c.Request.Body
-		}
-	} else {
-		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
-		}
-		jsonData, err := json.Marshal(convertedRequest)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
-		}
-		requestBody = bytes.NewBuffer(jsonData)
+
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	}
+	requestBody = bytes.NewBuffer(jsonData)
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
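Taken together, the patch routes every request through ConvertRequest (including the former raw-body passthrough for OpenAI), so the relay can inject or strip stream_options before re-marshalling. A rough client-side sketch of how the effect can be observed against a local deployment; the URL, port, model name and NEW_API_KEY environment variable are placeholders:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"strings"
)

func main() {
	// The client deliberately does NOT set stream_options.
	reqBody := []byte(`{"model":"gpt-3.5-turbo","stream":true,` +
		`"messages":[{"role":"user","content":"hi"}]}`)
	req, err := http.NewRequest("POST",
		"http://localhost:3000/v1/chat/completions", bytes.NewReader(reqBody))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+os.Getenv("NEW_API_KEY"))
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank keep-alive lines between SSE events
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			break
		}
		var chunk struct {
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
				TotalTokens      int `json:"total_tokens"`
			} `json:"usage"`
		}
		// With FORCE_STREAM_OPTION enabled, the final chunk should carry
		// a non-null usage object even though the client never asked.
		if json.Unmarshal([]byte(data), &chunk) == nil && chunk.Usage != nil {
			fmt.Printf("usage: %+v\n", *chunk.Usage)
		}
	}
}

With FORCE_STREAM_OPTION left at its default of true, the relay adds stream_options for OpenAI-type channels and the last SSE chunk should print a non-zero usage; other channel types fall back to the local token recount shown in the adaptors above.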