From 8de79382f023ac91a810e32e5d9ebaf4b92a8318 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 7 Oct 2024 17:18:11 +0800 Subject: [PATCH] feat: azure realtime (cherry picked from commit 75ff3d98f06103dc2df1f8817bd3fcbf433e0f20) --- relay/channel/openai/relay-openai.go | 26 +++++++++++++++----------- service/token_counter.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 1aef14e..6b11aac 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -478,6 +478,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + info.IsFirstRequest = false + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken } } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -494,17 +506,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) localUsage.TotalTokens += textToken + audioToken - - if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { - info.IsFirstRequest = false - localUsage.InputTokens += textToken + audioToken - localUsage.InputTokenDetails.TextTokens += textToken - localUsage.InputTokenDetails.AudioTokens += audioToken - } else { - localUsage.OutputTokens += textToken + audioToken - localUsage.OutputTokenDetails.TextTokens += textToken - localUsage.OutputTokenDetails.AudioTokens += audioToken - } + localUsage.OutputTokens += textToken + audioToken + localUsage.OutputTokenDetails.TextTokens += textToken + localUsage.OutputTokenDetails.AudioTokens += audioToken } err = service.WssString(c, clientConn, string(message)) diff --git a/service/token_counter.go b/service/token_counter.go index 63eb712..17fbe0a 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -225,6 +225,21 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return 0, 0, fmt.Errorf("error counting audio token: %v", err) } audioToken += atk + case dto.RealtimeEventConversationItemCreated: + if request.Item != nil { + switch request.Item.Type { + case "message": + for _, content := range request.Item.Content { + if content.Type == "input_text" { + tokens, err := CountTextToken(content.Text, model) + if err != nil { + return 0, 0, err + } + textToken += tokens + } + } + } + } case dto.RealtimeEventTypeResponseDone: // count tools token if !info.IsFirstRequest {