feat: support stream_options

CalciumIon 2024-07-08 01:27:57 +08:00
parent 20d71711d3
commit b0e234e8f5
11 changed files with 97 additions and 55 deletions

View File

@@ -73,7 +73,7 @@
 ## Configuration options beyond the original One API
 - `STREAMING_TIMEOUT`: timeout, in seconds, for a single streaming reply; defaults to 30
 - `DIFY_DEBUG`: whether the Dify channel sends workflow and node information to the client; defaults to `true`, allowed values are `true` and `false`
+- `FORCE_STREAM_OPTION`: override the client's `stream_options` parameter and ask the upstream to return usage in stream mode; defaults to `true`, currently only supported for the `OpenAI` channel type
 ## Deployment
 ### Deployment requirements
 - Local database (SQLite by default; Docker deployments use SQLite and must mount the `/data` directory to the host)

View File

@@ -6,3 +6,6 @@ import (
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
 var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
+
+// ForceStreamOption overrides the request's stream options to force the upstream to return usage information
+var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
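
For reference, a minimal sketch of what an env helper with this shape typically does. This is an assumption for illustration, not the project's actual `common.GetEnvOrDefaultBool` implementation:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// getEnvOrDefaultBool sketches the behavior assumed of the project's
// common.GetEnvOrDefaultBool: read a boolean env var, fall back to the
// default when it is unset or unparseable. Hypothetical, for illustration.
func getEnvOrDefaultBool(key string, defaultValue bool) bool {
	raw, ok := os.LookupEnv(key)
	if !ok {
		return defaultValue
	}
	parsed, err := strconv.ParseBool(raw) // accepts 1/0, t/f, true/false
	if err != nil {
		return defaultValue
	}
	return parsed
}

func main() {
	os.Setenv("FORCE_STREAM_OPTION", "false")
	fmt.Println(getEnvOrDefaultBool("FORCE_STREAM_OPTION", true)) // false
	fmt.Println(getEnvOrDefaultBool("UNSET_VAR", true))           // true
}
```

Under that assumption, `FORCE_STREAM_OPTION=false` in the environment turns the override off, and an unset or malformed value falls back to the default `true`.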

View File

@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
 	Messages      []Message      `json:"messages,omitempty"`
 	Prompt        any            `json:"prompt,omitempty"`
 	Stream        bool           `json:"stream,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	MaxTokens     uint           `json:"max_tokens,omitempty"`
 	Temperature   float64        `json:"temperature,omitempty"`
 	TopP          float64        `json:"top_p,omitempty"`
@@ -43,6 +44,10 @@ type OpenAIFunction struct {
 	Parameters any `json:"parameters,omitempty"`
 }
 
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
 	return int64(r.MaxTokens)
 }
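
With the new field and struct in place, a request that opts into streamed usage marshals to the wire format OpenAI documents for `stream_options`. A small self-contained sketch (the trimmed structs mirror the json tags above; the model name is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the DTO fields relevant to stream_options,
// mirroring the json tags from the diff above.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type ChatRequest struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := ChatRequest{
		Model:         "gpt-4o", // illustrative model name
		Stream:        true,
		StreamOptions: &StreamOptions{IncludeUsage: true},
	}
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
	// {"model":"gpt-4o","stream":true,"stream_options":{"include_usage":true}}
}
```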

View File

@@ -106,6 +106,7 @@ type ChatCompletionsStreamResponse struct {
 
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *Usage                                `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
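
The pointer type matters here: with `stream_options.include_usage` set, OpenAI sends usage only on the final chunk (whose `choices` array is empty), so `Usage` stays `nil` on every other chunk and the handler can tell "absent" from "zero". A sketch of decoding such a final chunk, using trimmed stand-ins for the DTO types:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Minimal mirrors of the DTO types above; Delta fields omitted for brevity.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type StreamChunk struct {
	Choices []json.RawMessage `json:"choices"`
	Usage   *Usage            `json:"usage"`
}

func main() {
	// Sample final chunk under include_usage: empty choices plus a usage
	// object (abbreviated payload, values illustrative).
	final := `{"choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46}}`
	var chunk StreamChunk
	if err := json.Unmarshal([]byte(final), &chunk); err != nil {
		panic(err)
	}
	if chunk.Usage != nil { // nil means the upstream sent no usage at all
		fmt.Println(chunk.Usage.TotalTokens) // 46
	}
}
```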

View File

@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
 			err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
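
This guard, repeated in the other adaptors below, prefers the upstream-reported usage and falls back to local token counting only when usage is missing or all-zero. A standalone restatement of the pattern (`estimateFromText` is a hypothetical stand-in for `service.ResponseText2Usage`, and the one-token-per-rune estimate is purely illustrative):

```go
package main

import "fmt"

type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// estimateFromText is a hypothetical stand-in for service.ResponseText2Usage;
// here it simply charges one token per rune of completion text.
func estimateFromText(responseText string, promptTokens int) *Usage {
	completion := len([]rune(responseText))
	return &Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completion,
		TotalTokens:      promptTokens + completion,
	}
}

// pickUsage mirrors the guard added in the adaptors: trust upstream usage
// when it is present and non-zero, otherwise estimate locally.
func pickUsage(upstream *Usage, responseText string, promptTokens int) *Usage {
	if upstream == nil || upstream.TotalTokens == 0 ||
		upstream.PromptTokens+upstream.CompletionTokens == 0 {
		return estimateFromText(responseText, promptTokens)
	}
	return upstream
}

func main() {
	fmt.Println(pickUsage(nil, "hello", 3).TotalTokens) // fallback estimate: 8
	fmt.Println(pickUsage(&Usage{3, 5, 8}, "hello", 3).TotalTokens) // upstream: 8
}
```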

View File

@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
@@ -20,6 +21,7 @@ import (
 
 type Adaptor struct {
 	ChannelType int
+	SupportStreamOptions bool
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -31,6 +33,7 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
 	a.ChannelType = info.ChannelType
+	a.SupportStreamOptions = info.SupportStreamOptions
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -78,6 +81,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// if the channel does not support StreamOptions, set StreamOptions to nil
+	if !a.SupportStreamOptions {
+		request.StreamOptions = nil
+	} else {
+		// if StreamOptions is supported, force usage reporting when
+		// FORCE_STREAM_OPTION is enabled (overrides any client-supplied value)
+		if constant.ForceStreamOption {
+			request.StreamOptions = &dto.StreamOptions{
+				IncludeUsage: true,
+			}
+		}
+	}
 	return request, nil
 }
@@ -89,9 +103,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
+		err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
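
One behavior worth noting in `ConvertRequest`: when `FORCE_STREAM_OPTION` is enabled it overwrites any `stream_options` the client supplied, and on channels without support the field is always stripped. Restated as a pure function for clarity (a sketch, not the project's code):

```go
package main

import "fmt"

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// applyStreamOptionsPolicy restates the branch in ConvertRequest: strip the
// field on channels that don't support it, force include_usage when
// FORCE_STREAM_OPTION is on, otherwise pass the client's value through.
func applyStreamOptionsPolicy(clientOpts *StreamOptions, channelSupports, forceOption bool) *StreamOptions {
	if !channelSupports {
		return nil // upstream would reject or ignore the unknown field
	}
	if forceOption {
		return &StreamOptions{IncludeUsage: true} // overrides whatever the client sent
	}
	return clientOpts
}

func main() {
	fmt.Println(applyStreamOptionsPolicy(nil, true, true))  // &{true}
	fmt.Println(applyStreamOptionsPolicy(nil, false, true)) // <nil>
}
```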

View File

@@ -18,9 +18,10 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	var usage dto.Usage
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,26 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			streamItems = append(streamItems, data)
 		}
 	}
+	// count tokens
 	streamResp := "[" + strings.Join(streamItems, ",") + "]"
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
 		var streamResponses []dto.ChatCompletionsStreamResponseSimple
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
+			// bulk unmarshal failed, fall back to parsing the items one by one
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			for _, item := range streamItems {
 				var streamResponse dto.ChatCompletionsStreamResponseSimple
 				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
 				if err == nil {
+					if streamResponse.Usage != nil {
+						if streamResponse.Usage.TotalTokens != 0 {
+							usage.PromptTokens += streamResponse.Usage.PromptTokens
+							usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+							usage.TotalTokens += streamResponse.Usage.TotalTokens
+						}
+					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
 						if choice.Delta.ToolCalls != nil {
@@ -89,6 +99,13 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			}
 		} else {
 			for _, streamResponse := range streamResponses {
+				if streamResponse.Usage != nil {
+					if streamResponse.Usage.TotalTokens != 0 {
+						usage.PromptTokens += streamResponse.Usage.PromptTokens
+						usage.CompletionTokens += streamResponse.Usage.CompletionTokens
+						usage.TotalTokens += streamResponse.Usage.TotalTokens
+					}
+				}
 				for _, choice := range streamResponse.Choices {
 					responseTextBuilder.WriteString(choice.Delta.GetContentString())
 					if choice.Delta.ToolCalls != nil {
@@ -107,6 +124,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		var streamResponses []dto.CompletionsStreamResponse
 		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
 		if err != nil {
+			// bulk unmarshal failed, fall back to parsing the items one by one
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			for _, item := range streamItems {
 				var streamResponse dto.CompletionsStreamResponse
@@ -159,10 +177,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String(), toolCount
+	return nil, &usage, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

View File

@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		var toolCount int
-		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
+		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
 			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 			usage.CompletionTokens += toolCount * 7
+		}
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

View File

@@ -27,6 +27,7 @@ type RelayInfo struct {
 	ApiKey       string
 	Organization string
 	BaseUrl      string
+	SupportStreamOptions bool
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -65,6 +66,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
+	if info.ChannelType == common.ChannelTypeOpenAI {
+		info.SupportStreamOptions = true
+	}
 	return info
 }

View File

@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	// map model name
 	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
+	//isModelMapped := false
 	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if modelMap[textRequest.Model] != "" {
 			textRequest.Model = modelMap[textRequest.Model]
 			// set upstream model name
-			isModelMapped = true
+			//isModelMapped = true
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
@@ -136,17 +136,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
-	if relayInfo.ApiType == relayconstant.APITypeOpenAI {
-		if isModelMapped {
-			jsonStr, err := json.Marshal(textRequest)
-			if err != nil {
-				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
-			}
-			requestBody = bytes.NewBuffer(jsonStr)
-		} else {
-			requestBody = c.Request.Body
-		}
-	} else {
 		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
@@ -156,7 +146,6 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonData)
-	}
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)