Compare commits

...

3 Commits

Author       SHA1        Message                        Date
CalciumIon   b0e234e8f5  feat: support stream_options   2024-07-08 01:27:57 +08:00
CalciumIon   20d71711d3  feat: add env DIFY_DEBUG       2024-07-07 02:24:51 +08:00
CalciumIon   4246c4cdc1  fix: streaming timeout         2024-07-07 01:09:56 +08:00
13 changed files with 116 additions and 59 deletions

View File

@@ -72,7 +72,8 @@
## Configuration added on top of the original One API
- `STREAMING_TIMEOUT`: timeout, in seconds, for receiving each chunk of a streaming reply; defaults to 30
- `DIFY_DEBUG`: whether the Dify channel sends workflow and node information to the client; defaults to `true`, valid values are `true` and `false`
- `FORCE_STREAM_OPTION`: override the client's `stream_options` parameter and request usage from the upstream in stream mode; currently only supported for the `OpenAI` channel type
## Deployment
### Deployment requirements
- Local database: SQLite by default (Docker deployments use SQLite by default; the `/data` directory must be mounted on the host)

View File

@@ -24,3 +24,15 @@ func GetEnvOrDefaultString(env string, defaultValue string) string {
}
return os.Getenv(env)
}
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
b, err := strconv.ParseBool(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
return defaultValue
}
return b
}
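
Note: the new GetEnvOrDefaultBool helper delegates parsing to strconv.ParseBool, so the usual Go boolean spellings all work. A minimal, self-contained sketch of the behavior (the main function, the getBool name, and the env values are illustrative, not part of the diff):

package main

import (
	"fmt"
	"os"
	"strconv"
)

// Same parsing logic as GetEnvOrDefaultBool above, inlined so the
// sketch compiles on its own.
func getBool(env string, defaultValue bool) bool {
	v := os.Getenv(env)
	if v == "" {
		return defaultValue
	}
	// ParseBool accepts 1/t/T/TRUE/true/True and 0/f/F/FALSE/false/False.
	b, err := strconv.ParseBool(v)
	if err != nil {
		return defaultValue
	}
	return b
}

func main() {
	os.Setenv("DIFY_DEBUG", "false")
	fmt.Println(getBool("DIFY_DEBUG", true))          // false
	fmt.Println(getBool("FORCE_STREAM_OPTION", true)) // true (unset, falls back to default)
	os.Setenv("DIFY_DEBUG", "not-a-bool")
	fmt.Println(getBool("DIFY_DEBUG", true))          // true (parse error, falls back to default)
}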

View File

@@ -5,3 +5,7 @@ import (
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption overrides the client's request parameters and forces the upstream to return usage information
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

View File

@@ -11,6 +11,7 @@ type GeneralOpenAIRequest struct {
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
@@ -43,6 +44,10 @@ type OpenAIFunction struct {
Parameters any `json:"parameters,omitempty"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
}
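
Note: with the new StreamOptions field, a streaming request can ask the upstream to append a final usage chunk. A sketch of the resulting wire format (struct definitions trimmed to the fields used here; the model name is illustrative):

package main

import (
	"encoding/json"
	"fmt"
)

type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// Trimmed copy of GeneralOpenAIRequest with only the relevant fields.
type GeneralOpenAIRequest struct {
	Model         string         `json:"model,omitempty"`
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := GeneralOpenAIRequest{
		Model:         "gpt-4o", // illustrative
		Stream:        true,
		StreamOptions: &StreamOptions{IncludeUsage: true},
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"model":"gpt-4o","stream":true,"stream_options":{"include_usage":true}}
}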

View File

@@ -106,6 +106,7 @@ type ChatCompletionsStreamResponse struct {
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type CompletionsStreamResponse struct {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -48,9 +49,9 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if difyResponse.Event == "workflow_started" {
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if difyResponse.Event == "node_started" {
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
choice.Delta.SetContentString(difyResponse.Answer)

View File

@@ -59,8 +59,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
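
Note: this adaptor (and the ones below) now share the same fallback rule: prefer the usage reported by the upstream stream, and only re-tokenize the accumulated response text when nothing usable arrived. The rule in isolation, with Usage standing in for dto.Usage and countTokens for service.ResponseText2Usage:

type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// usageOrFallback returns the upstream-reported usage when it is present
// and non-zero; otherwise it recomputes usage from the response text.
func usageOrFallback(u *Usage, responseText string, countTokens func(string) *Usage) *Usage {
	if u == nil || u.TotalTokens == 0 || (u.PromptTokens+u.CompletionTokens) == 0 {
		return countTokens(responseText)
	}
	return u
}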

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
@@ -19,7 +20,8 @@ import (
)
type Adaptor struct {
ChannelType int
ChannelType int
SupportStreamOptions bool
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -31,6 +33,7 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
a.SupportStreamOptions = info.SupportStreamOptions
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -78,6 +81,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
// If StreamOptions is not supported, strip it by setting StreamOptions to nil
if !a.SupportStreamOptions {
request.StreamOptions = nil
} else {
// StreamOptions is supported: if FORCE_STREAM_OPTION is enabled, override the client's stream_options and request usage
if constant.ForceStreamOption {
request.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
return request, nil
}
@@ -89,9 +103,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
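
Note: ConvertRequest now normalizes stream_options per channel. The same logic, factored into a standalone function for clarity (reusing the trimmed types from the earlier sketch; the two flags correspond to a.SupportStreamOptions and constant.ForceStreamOption, and the function name is illustrative):

// applyStreamOptions strips stream_options for channels that cannot
// handle it, and otherwise force-requests usage when the override flag
// (FORCE_STREAM_OPTION) is enabled.
func applyStreamOptions(req *GeneralOpenAIRequest, supportStreamOptions, forceStreamOption bool) {
	if !supportStreamOptions {
		req.StreamOptions = nil
		return
	}
	if forceStreamOption {
		req.StreamOptions = &StreamOptions{IncludeUsage: true}
	}
}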

View File

@@ -18,9 +18,10 @@ import (
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string, int) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
var usage dto.Usage
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -62,17 +63,26 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
streamItems = append(streamItems, data)
}
}
// count tokens
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// bulk unmarshal failed; fall back to parsing the items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage.PromptTokens += streamResponse.Usage.PromptTokens
usage.CompletionTokens += streamResponse.Usage.CompletionTokens
usage.TotalTokens += streamResponse.Usage.TotalTokens
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -89,6 +99,13 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
} else {
for _, streamResponse := range streamResponses {
if streamResponse.Usage != nil {
if streamResponse.Usage.TotalTokens != 0 {
usage.PromptTokens += streamResponse.Usage.PromptTokens
usage.CompletionTokens += streamResponse.Usage.CompletionTokens
usage.TotalTokens += streamResponse.Usage.TotalTokens
}
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -107,6 +124,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// bulk unmarshal failed; fall back to parsing the items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
@@ -133,7 +151,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}()
service.SetEventStreamHeaders(c)
isFirst := true
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout))
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
c.Stream(func(w io.Writer) bool {
select {
@@ -145,7 +163,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
isFirst = false
info.FirstResponseTime = time.Now()
}
ticker.Reset(time.Duration(constant.StreamingTimeout))
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
@@ -159,10 +177,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
}
wg.Wait()
return nil, responseTextBuilder.String(), toolCount
return nil, &usage, responseTextBuilder.String(), toolCount
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
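
Note: the timeout fix above is easy to miss: time.Duration(30) is 30 nanoseconds, not 30 seconds, so the old ticker fired almost immediately on slow upstreams. A self-contained sketch of the corrected ticker pattern:

package main

import (
	"fmt"
	"time"
)

func main() {
	seconds := 30
	wrong := time.Duration(seconds)               // 30ns: the pre-fix bug
	right := time.Duration(seconds) * time.Second // 30s: the fixed value
	fmt.Println(wrong, right)                     // prints "30ns 30s"

	data := make(chan string)
	ticker := time.NewTicker(right)
	defer ticker.Stop()
	go func() { data <- "data: {...}" }()
	select {
	case d := <-data:
		// Any received chunk pushes the idle deadline out again.
		ticker.Reset(right)
		fmt.Println("got:", d)
	case <-ticker.C:
		fmt.Println("streaming timeout")
	}
}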

View File

@@ -55,8 +55,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -57,9 +57,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -9,24 +9,25 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
SupportStreamOptions bool
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -65,6 +66,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI {
info.SupportStreamOptions = true
}
return info
}

View File

@@ -77,7 +77,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -87,7 +87,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
isModelMapped = true
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
@@ -136,27 +136,16 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
adaptor.Init(relayInfo, *textRequest)
var requestBody io.Reader
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)