From 02545e4856fb0c6dcae6a53f66437fbee6e133d6 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 8 Jul 2024 19:46:45 +0800 Subject: [PATCH] fix: baidu max_output_tokens (close #353) --- relay/channel/baidu/dto.go | 2 +- relay/channel/baidu/relay-baidu.go | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/relay/channel/baidu/dto.go b/relay/channel/baidu/dto.go index 5168c11..f0c6f04 100644 --- a/relay/channel/baidu/dto.go +++ b/relay/channel/baidu/dto.go @@ -19,7 +19,7 @@ type BaiduChatRequest struct { System string `json:"system,omitempty"` DisableSearch bool `json:"disable_search,omitempty"` EnableCitation bool `json:"enable_citation,omitempty"` - MaxOutputTokens int `json:"max_output_tokens,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` UserId string `json:"user_id,omitempty"` } diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index d9e93c7..e313316 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -23,14 +23,17 @@ var baiduTokenStore sync.Map func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { baiduRequest := BaiduChatRequest{ - Temperature: request.Temperature, - TopP: request.TopP, - PenaltyScore: request.FrequencyPenalty, - Stream: request.Stream, - DisableSearch: false, - EnableCitation: false, - MaxOutputTokens: int(request.MaxTokens), - UserId: request.User, + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + UserId: request.User, + } + if request.MaxTokens != 0 { + maxTokens := int(request.MaxTokens) + baiduRequest.MaxOutputTokens = &maxTokens } for _, message := range request.Messages { if message.Role == "system" {