✨ feat: Support stream_options
@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the request options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Clear the field so it is not passed to an upstream that would reject it
+		request.StreamOptions = nil
+	}
+
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
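For context: in the OpenAI chat completions API, a streaming request that sets stream_options to {"include_usage": true} receives one extra final chunk whose choices array is empty and whose usage field carries the token totals. The provider therefore saves the caller's StreamOptions, forces IncludeUsage on upstreams that support the field, clears it on those that do not, and restores it after the body is built. A minimal runnable sketch of the resulting body difference; StreamOptions and ChatRequest below are simplified stand-ins, not one-api's types package:

package main

import (
	"encoding/json"
	"fmt"
)

// StreamOptions mirrors the OpenAI "stream_options" request field.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

// ChatRequest is a stand-in for one-api's types.ChatCompletionRequest,
// reduced to the fields that matter here.
type ChatRequest struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func main() {
	req := ChatRequest{Model: "gpt-3.5-turbo", Stream: true}

	// Upstream supports stream_options: ask for the trailing usage chunk.
	req.StreamOptions = &StreamOptions{IncludeUsage: true}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // ...,"stream_options":{"include_usage":true}}

	// Upstream does not support it: omitempty drops the field entirely,
	// which is what setting request.StreamOptions = nil achieves above.
	req.StreamOptions = nil
	b, _ = json.Marshal(req)
	fmt.Println(string(b)) // {"model":"gpt-3.5-turbo","stream":true}
}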
@@ -40,6 +40,16 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the request options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Clear the field so it is not passed to an upstream that would reject it
+		request.StreamOptions = nil
+	}
 	p.getChatRequestBody(request)
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
@@ -47,6 +57,9 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
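The Groq provider repeats the same save/override/restore pattern as Baichuan (and OpenAI below). The point of restoring request.StreamOptions after GetRequestTextBody is that the override is visible only while the body is marshalled; code that inspects the request afterwards still sees what the client sent. A standalone sketch of that pattern, again with simplified stand-in types:

package main

import (
	"encoding/json"
	"fmt"
)

type Options struct {
	IncludeUsage bool `json:"include_usage"`
}

type Request struct {
	Model         string   `json:"model"`
	StreamOptions *Options `json:"stream_options,omitempty"`
}

// buildBody shows the mutate-serialize-restore pattern used above: the
// override lives only for the duration of the marshal call.
func buildBody(req *Request, supportStreamOptions bool) []byte {
	saved := req.StreamOptions
	if supportStreamOptions {
		req.StreamOptions = &Options{IncludeUsage: true}
	} else {
		req.StreamOptions = nil
	}
	body, _ := json.Marshal(req)
	req.StreamOptions = saved // restore the caller's original value
	return body
}

func main() {
	req := &Request{Model: "gpt-4"}
	fmt.Println(string(buildBody(req, true))) // body carries include_usage
	fmt.Println(req.StreamOptions == nil)     // true: caller state intact
}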
@@ -17,8 +17,9 @@ type OpenAIProviderFactory struct{}
 
 type OpenAIProvider struct {
 	base.BaseProvider
-	IsAzure       bool
-	BalanceAction bool
+	IsAzure              bool
+	BalanceAction        bool
+	SupportStreamOptions bool
 }
 
 // Create an OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
 	config := getOpenAIConfig(baseURL)
 
-	return &OpenAIProvider{
+	OpenAIProvider := &OpenAIProvider{
 		BaseProvider: base.BaseProvider{
 			Config:  config,
 			Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
 		IsAzure:       false,
 		BalanceAction: true,
 	}
+
+	if channel.Type == common.ChannelTypeOpenAI {
+		OpenAIProvider.SupportStreamOptions = true
+	}
+
+	return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {
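Only channels of type common.ChannelTypeOpenAI advertise stream_options support; Azure and other OpenAI-compatible channels keep the token-counting fallback. A self-contained sketch of the gating, where the channel-type constants are illustrative stand-ins rather than the project's real values:

package main

import "fmt"

// Stand-ins for one-api's channel-type constants; the numeric values
// here are illustrative, not the project's real ones.
const (
	ChannelTypeOpenAI = 1
	ChannelTypeAzure  = 3
)

type OpenAIProvider struct {
	SupportStreamOptions bool
}

// newProvider mirrors the constructor's gating: only native OpenAI
// channels advertise stream_options support.
func newProvider(channelType int) *OpenAIProvider {
	p := &OpenAIProvider{}
	if channelType == ChannelTypeOpenAI {
		p.SupportStreamOptions = true
	}
	return p
}

func main() {
	fmt.Println(newProvider(ChannelTypeOpenAI).SupportStreamOptions) // true
	fmt.Println(newProvider(ChannelTypeAzure).SupportStreamOptions)  // false
}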
@@ -8,7 +8,6 @@ import (
 	"one-api/common/requester"
 	"one-api/types"
 	"strings"
-	"time"
 )
 
 type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the request options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Clear the field so it is not passed to an upstream that would reject it
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
 	}
 
+	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
+		*rawLine = nil
+		return
+	}
+
 	dataChan <- string(*rawLine)
 
-	if h.isAzure {
-		// block for 20ms
-		time.Sleep(20 * time.Millisecond)
-	}
-
-	if h.Usage.TotalTokens == 0 {
-		h.Usage.TotalTokens = h.Usage.PromptTokens
-	}
-
-	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
-	h.Usage.CompletionTokens += countTokenText
-	h.Usage.TotalTokens += countTokenText
+	if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
+		*h.Usage = *openaiResponse.Choices[0].Usage
+	} else {
+		if h.Usage.TotalTokens == 0 {
+			h.Usage.TotalTokens = h.Usage.PromptTokens
+		}
+		countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+		h.Usage.CompletionTokens += countTokenText
+		h.Usage.TotalTokens += countTokenText
+	}
 }
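With include_usage enabled, the last data line before [DONE] is exactly the case the new len(openaiResponse.Choices) == 0 branch handles: an empty choices array plus the authoritative usage totals, which are recorded rather than forwarded downstream. A runnable sketch of that branch, with simplified stand-in types:

package main

import (
	"encoding/json"
	"fmt"
)

type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// chunk keeps only the fields the branch above inspects.
type chunk struct {
	Choices []json.RawMessage `json:"choices"`
	Usage   *Usage            `json:"usage"`
}

func main() {
	// The trailing usage chunk sent when include_usage is true: an empty
	// choices array plus the final token totals.
	line := `{"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`

	var c chunk
	if err := json.Unmarshal([]byte(line), &c); err != nil {
		panic(err)
	}

	// Same branch as the handler: no choices means this is the usage-only
	// chunk, so record the totals and do not forward the line downstream.
	if len(c.Choices) == 0 && c.Usage != nil {
		fmt.Printf("usage: %+v\n", *c.Usage)
	}
}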
@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the request options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Clear the field so it is not passed to an upstream that would reject it
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
 		return
 	}
 
+	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
+		*rawLine = nil
+		return
+	}
+
 	dataChan <- string(*rawLine)
 
+	if h.Usage.TotalTokens == 0 {
+		h.Usage.TotalTokens = h.Usage.PromptTokens
+	}
 	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
 	h.Usage.CompletionTokens += countTokenText
 	h.Usage.TotalTokens += countTokenText
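When no usage chunk arrives (the else branch above, and the default path for the completion handler), the relay keeps estimating: TotalTokens is seeded once with PromptTokens, then every forwarded fragment adds its counted tokens to both CompletionTokens and TotalTokens. A sketch of that accounting; countTokenText is a crude stand-in for common.CountTokenText, not the project's real tokenizer:

package main

import (
	"fmt"
	"strings"
)

type Usage struct {
	PromptTokens, CompletionTokens, TotalTokens int
}

// countTokenText stands in for common.CountTokenText; a whitespace split
// is used purely for illustration.
func countTokenText(text, model string) int {
	return len(strings.Fields(text))
}

func main() {
	u := &Usage{PromptTokens: 9}

	for _, fragment := range []string{"Hello there,", "how are you?"} {
		// Seed TotalTokens with the prompt cost exactly once.
		if u.TotalTokens == 0 {
			u.TotalTokens = u.PromptTokens
		}
		n := countTokenText(fragment, "gpt-3.5-turbo")
		u.CompletionTokens += n
		u.TotalTokens += n
	}
	fmt.Printf("%+v\n", *u) // {PromptTokens:9 CompletionTokens:5 TotalTokens:14}
}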