feat: Support stream_options

MartialBE
2024-05-26 19:58:15 +08:00
parent fa54ca7b50
commit eb260652b2
11 changed files with 188 additions and 31 deletions
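
Background: the OpenAI streaming API accepts a stream_options object on chat and text completion requests; when include_usage is set, the server appends one final chunk whose choices array is empty and whose usage field carries the token counts for the whole call. A minimal sketch of the request shape this commit toggles, using the type names from the diff (the JSON tags are assumptions based on the public OpenAI wire format, not read from this repository):

    // Sketch only: IncludeUsage mirrors types.StreamOptions as used below;
    // the json tags assume OpenAI's documented field names.
    type StreamOptions struct {
        IncludeUsage bool `json:"include_usage"`
    }

    type ChatCompletionRequest struct {
        Model         string         `json:"model"`
        Stream        bool           `json:"stream,omitempty"`
        StreamOptions *StreamOptions `json:"stream_options,omitempty"`
        // ...remaining fields elided
    }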

View File

@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the upstream supports returning usage in the stream, request it:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Avoid forwarding an unsupported field that would cause an error
+        request.StreamOptions = nil
+    }
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the original options
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {

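The save/override/restore sequence above recurs verbatim in the Groq and OpenAI stream methods below: the caller's stream_options are stashed, overridden or cleared while GetRequestTextBody serializes the body, then put back. A hypothetical helper capturing that pattern (not part of this commit; the name is illustrative):

    // Hypothetical sketch, not in this commit: scope the stream-options
    // override to request serialization, then hand back a restore func.
    func overrideStreamOptions(request *types.ChatCompletionRequest, supported bool) (restore func()) {
        saved := request.StreamOptions
        if supported {
            request.StreamOptions = &types.StreamOptions{IncludeUsage: true}
        } else {
            request.StreamOptions = nil // avoid forwarding an unsupported field
        }
        return func() { request.StreamOptions = saved }
    }

Each call site would run restore := overrideStreamOptions(request, p.SupportStreamOptions) before GetRequestTextBody and call restore() right after it, matching the manual sequence in the diff.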
View File

@@ -40,6 +40,16 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the upstream supports returning usage in the stream, request it:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Avoid forwarding an unsupported field that would cause an error
+        request.StreamOptions = nil
+    }
     p.getChatRequestBody(request)
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
@@ -47,6 +57,9 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
     }
     defer req.Body.Close()
 
+    // Restore the original options
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {

View File

@@ -17,8 +17,9 @@ type OpenAIProviderFactory struct{}
 type OpenAIProvider struct {
     base.BaseProvider
-    IsAzure       bool
-    BalanceAction bool
+    IsAzure              bool
+    BalanceAction        bool
+    SupportStreamOptions bool
 }
 
 // Create an OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
     config := getOpenAIConfig(baseURL)
-    return &OpenAIProvider{
+    OpenAIProvider := &OpenAIProvider{
         BaseProvider: base.BaseProvider{
             Config:  config,
             Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
         IsAzure:       false,
         BalanceAction: true,
     }
+
+    if channel.Type == common.ChannelTypeOpenAI {
+        OpenAIProvider.SupportStreamOptions = true
+    }
+
+    return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {

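Note the gate: only channels of type common.ChannelTypeOpenAI opt into SupportStreamOptions. Providers that reuse this plumbing keep the zero value false, which is why Baichuan and Groq above clear stream_options before forwarding. Judging by the shared method set, those providers appear to embed OpenAIProvider; a sketch of that assumed layout (the struct definition is not part of this diff):

    // Assumed, inferred from the shared methods in this commit:
    type BaichuanProvider struct {
        OpenAIProvider // SupportStreamOptions stays false unless a factory sets it
    }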
View File

@@ -8,7 +8,6 @@ import (
"one-api/common/requester"
"one-api/types"
"strings"
"time"
)
type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the upstream supports returning usage in the stream, request it:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Avoid forwarding an unsupported field that would cause an error
+        request.StreamOptions = nil
+    }
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the original options
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
     }
     if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
         *rawLine = nil
         return
     }
 
     dataChan <- string(*rawLine)
 
-    if h.isAzure {
-        // Block for 20ms
-        time.Sleep(20 * time.Millisecond)
+    if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
+        *h.Usage = *openaiResponse.Choices[0].Usage
+    } else {
+        if h.Usage.TotalTokens == 0 {
+            h.Usage.TotalTokens = h.Usage.PromptTokens
+        }
+        countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+        h.Usage.CompletionTokens += countTokenText
+        h.Usage.TotalTokens += countTokenText
     }
-
-    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
-    h.Usage.CompletionTokens += countTokenText
-    h.Usage.TotalTokens += countTokenText
 }

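The zero-choices branch above consumes the extra chunk that include_usage produces: the upstream sends it after the last content chunk and before the [DONE] sentinel, with an empty choices array and the aggregate usage. An illustrative payload (values invented for the example):

    data: {"object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}

When no such chunk arrives, the else branch estimates usage locally via common.CountTokenText.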
View File

@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the upstream supports returning usage in the stream, request it:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Avoid forwarding an unsupported field that would cause an error
+        request.StreamOptions = nil
+    }
     req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
     if errWithCode != nil {
         return nil, errWithCode
     }
     defer req.Body.Close()
 
+    // Restore the original options
+    request.StreamOptions = streamOptions
+
     // Send the request
     resp, errWithCode := p.Requester.SendRequestRaw(req)
     if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
         return
     }
     if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
         *rawLine = nil
         return
     }
 
     dataChan <- string(*rawLine)
 
+    if h.Usage.TotalTokens == 0 {
+        h.Usage.TotalTokens = h.Usage.PromptTokens
+    }
     countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
     h.Usage.CompletionTokens += countTokenText
     h.Usage.TotalTokens += countTokenText
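
When the upstream reports no usage, the fallback above seeds TotalTokens from PromptTokens once, then adds each chunk's estimated token count to both counters. A worked trace with invented numbers (PromptTokens = 9; CountTokenText returns 3, then 2 for two successive chunks):

    chunk 1: TotalTokens 0 -> 9 (seeded); CompletionTokens 0+3 = 3; TotalTokens 9+3 = 12
    chunk 2: CompletionTokens 3+2 = 5; TotalTokens 12+2 = 14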