diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index d549205..ce97755 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -214,10 +214,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return false } }) - response := service.GenerateFinalUsageResponse(id, createdTime, model, usage) - err = service.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) + err := service.ObjectData(c, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) + } } service.Done(c) err = resp.Body.Close() diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 9090146..9457f1e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -349,13 +349,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) } } - response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) - if err != nil { - common.SysError("send final response failed: " + err.Error()) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) + } } service.Done(c) - err = resp.Body.Close() + err := resp.Body.Close() if err != nil { return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 0c15793..00f01fd 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" @@ -20,8 +19,7 @@ import ( ) type Adaptor struct { - ChannelType int - SupportStreamOptions bool + ChannelType int } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -33,7 +31,6 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { a.ChannelType = info.ChannelType - a.SupportStreamOptions = info.SupportStreamOptions } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -81,17 +78,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen if request == nil { return nil, errors.New("request is nil") } - // 如果不支持StreamOptions,将StreamOptions设置为nil - if !a.SupportStreamOptions || !request.Stream { - request.StreamOptions = nil - } else { - // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions - if constant.ForceStreamOption { - request.StreamOptions = &dto.StreamOptions{ - IncludeUsage: true, - } - } - } return request, nil } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index cdad472..6794f0e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -28,6 +28,7 @@ type RelayInfo struct { Organization string BaseUrl string SupportStreamOptions bool + ShouldIncludeUsage bool } func GenRelayInfo(c *gin.Context) *RelayInfo { diff --git a/relay/relay-text.go b/relay/relay-text.go index 7d3ad2c..6e74fbb 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -130,6 +130,22 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return openaiErr } + // 如果不支持StreamOptions,将StreamOptions设置为nil + if !relayInfo.SupportStreamOptions || !textRequest.Stream { + textRequest.StreamOptions = nil + } else { + // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions + if constant.ForceStreamOption { + textRequest.StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + } + } + + if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { + relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage + } + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)