From be64408a250f60bca8d5e49171be71ed5a943300 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Thu, 10 Oct 2024 00:15:27 +0800 Subject: [PATCH] =?UTF-8?q?fix(realtime):=20=E4=BF=AE=E5=A4=8Dws=20?= =?UTF-8?q?=E6=8F=A1=E6=89=8B=E5=A4=B1=E8=B4=A5=E3=80=81=E8=AE=A1=E8=B4=B9?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 618dffc43fd5a5f4065944db87761f9ee18e44d3) --- controller/relay.go | 1 + relay/channel/openai/adaptor.go | 27 +++++++++++++++++++++------ relay/channel/openai/relay-openai.go | 4 ++++ service/quota.go | 6 ++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index fe65d96..c2e7523 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -145,6 +145,7 @@ func Relay(c *gin.Context) { } var upgrader = websocket.Upgrader{ + Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol CheckOrigin: func(r *http.Request) bool { return true // 允许跨域 }, diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index a663d15..def0850 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -63,18 +63,33 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { - channel.SetupApiRequestHeader(info, c, req) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, header) if info.ChannelType == common.ChannelTypeAzure { - req.Set("api-key", info.ApiKey) + header.Set("api-key", info.ApiKey) return nil } if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { - req.Set("OpenAI-Organization", info.Organization) + header.Set("OpenAI-Organization", info.Organization) } - req.Set("Authorization", "Bearer "+info.ApiKey) if info.RelayMode == constant.RelayModeRealtime { - req.Set("openai-beta", "realtime=v1") + swp := c.Request.Header.Get("Sec-WebSocket-Protocol") + if swp != "" { + items := []string{ + "realtime", + "openai-insecure-api-key." + info.ApiKey, + "openai-beta.realtime-v1", + } + header.Set("Sec-WebSocket-Protocol", strings.Join(items, ",")) + //req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key")) + //req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions")) + //req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version")) + } else { + header.Set("openai-beta", "realtime=v1") + header.Set("Authorization", "Bearer "+info.ApiKey) + } + } else { + header.Set("Authorization", "Bearer "+info.ApiKey) } //if info.ChannelType == common.ChannelTypeOpenRouter { // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 400e111..d172525 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -483,7 +483,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error consume usage: %v", err) return } + // 本次计费完成,清除 usage = &dto.RealtimeUsage{} + + localUsage = &dto.RealtimeUsage{} } else { textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) if err != nil { @@ -501,6 +504,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error consume usage: %v", err) return } + // 本次计费完成,清除 localUsage = &dto.RealtimeUsage{} // print now usage } diff --git a/service/quota.go b/service/quota.go index 9a3c542..7dd49b9 100644 --- a/service/quota.go +++ b/service/quota.go @@ -78,10 +78,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod quota := 0 if !usePrice { - quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - - quota = int(math.Round(float64(quota) * ratio)) + quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio)) if ratio != 0 && quota <= 0 { quota = 1 }