From 74f9006b404ebaabc021bfafac43c9c52b87312e Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sun, 6 Oct 2024 14:13:41 +0800 Subject: [PATCH] feat: realtime (cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425) --- common/model-ratio.go | 16 ++- dto/realtime.go | 12 +- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/relay_cloudflare.go | 2 +- relay/channel/dify/relay-dify.go | 2 +- relay/channel/openai/adaptor.go | 4 +- relay/channel/openai/relay-openai.go | 63 +++++++++- relay/channel/palm/relay-palm.go | 2 +- relay/common/relay_info.go | 8 ++ relay/relay-audio.go | 2 +- relay/websocket.go | 22 ++-- service/audio.go | 31 +++++ service/relay.go | 4 +- service/token_counter.go | 117 +++++++++++++------ service/usage_helpr.go | 2 +- 15 files changed, 227 insertions(+), 62 deletions(-) create mode 100644 service/audio.go diff --git a/common/model-ratio.go b/common/model-ratio.go index c037b8b..6eab850 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -432,9 +432,23 @@ func GetAudioCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gpt-4o-realtime") { return 10 } - return 10 + return 2 } +//func GetAudioPricePerMinute(name string) float64 { +// if strings.HasPrefix(name, "gpt-4o-realtime") { +// return 0.06 +// } +// return 0.06 +//} +// +//func GetAudioCompletionPricePerMinute(name string) float64 { +// if strings.HasPrefix(name, "gpt-4o-realtime") { +// return 0.24 +// } +// return 0.24 +//} + func GetCompletionRatioMap() map[string]float64 { if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio diff --git a/dto/realtime.go b/dto/realtime.go index c470730..cca99f3 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -5,10 +5,18 @@ const ( RealtimeEventTypeSessionUpdate = "session.update" RealtimeEventTypeConversationCreate = "conversation.item.create" RealtimeEventTypeResponseCreate = "response.create" + RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append" ) const ( - RealtimeEventTypeResponseDone = "response.done" + RealtimeEventTypeResponseDone = "response.done" + RealtimeEventTypeSessionUpdated = "session.updated" + RealtimeEventTypeSessionCreated = "session.created" + RealtimeEventResponseAudioDelta = "response.audio.delta" + RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta" + RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta" + RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done" + RealtimeEventConversationItemCreated = "conversation.item.created" ) type RealtimeEvent struct { @@ -19,6 +27,8 @@ type RealtimeEvent struct { Item *RealtimeItem `json:"item,omitempty"` Error *OpenAIError `json:"error,omitempty"` Response *RealtimeResponse `json:"response,omitempty"` + Delta string `json:"delta,omitempty"` + Audio string `json:"audio,omitempty"` } type RealtimeResponse struct { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 781b9a7..4c7f188 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r }, nil } fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) + completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 69d6b85..d21e524 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 66ba839..5df34d3 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } if usage.TotalTokens == 0 { usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) + usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } return nil, usage diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 5ac0306..a663d15 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { model_ := info.UpstreamModelName model_ = strings.Replace(model_, ".", "", -1) // https://github.com/songquanpeng/one-api/issues/67 - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + if info.RelayMode == constant.RelayModeRealtime { + requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion) + } return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil case common.ChannelTypeMiniMax: return minimax.GetRequestURL(info) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d2eccb3..1aef14e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" + "log" "net/http" "one-api/common" "one-api/constant" @@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) + ctkm, _ := service.CountTextToken(string(choice.Message.Content), model) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ @@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) + usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage } @@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan := make(chan error, 2) usage := &dto.RealtimeUsage{} + localUsage := &dto.RealtimeUsage{} go func() { for { @@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op return } + realtimeEvent := &dto.RealtimeEvent{} + err = json.Unmarshal(message, realtimeEvent) + if err != nil { + errChan <- fmt.Errorf("error unmarshalling message: %v", err) + return + } + + if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate { + if realtimeEvent.Session != nil { + if realtimeEvent.Session.Tools != nil { + info.RealtimeTools = realtimeEvent.Session.Tools + } + } + } + + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + localUsage.InputTokens += textToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + err = service.WssString(c, targetConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to target: %v", err) @@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens } + } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { + realtimeSession := realtimeEvent.Session + if realtimeSession != nil { + // update audio format + info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat) + info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat) + } + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + + if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { + info.IsFirstRequest = false + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + } else { + localUsage.OutputTokens += textToken + audioToken + localUsage.OutputTokenDetails.TextTokens += textToken + localUsage.OutputTokenDetails.AudioTokens += audioToken + } } err = service.WssString(c, clientConn, string(message)) @@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op case <-c.Done(): } + // check usage total tokens, if 0, use local usage + + if usage.TotalTokens == 0 { + usage = localUsage + } return nil, usage } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 47588a2..dfde59f 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) + completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index bd29b9c..b43f917 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "one-api/common" + "one-api/dto" "one-api/relay/constant" "strings" "time" @@ -35,11 +36,18 @@ type RelayInfo struct { ShouldIncludeUsage bool ClientWs *websocket.Conn TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { info := GenRelayInfo(c) info.ClientWs = ws + info.InputAudioFormat = "pcm16" + info.OutputAudioFormat = "pcm16" + info.IsFirstRequest = true return info } diff --git a/relay/relay-audio.go b/relay/relay-audio.go index e1d0a70..b65f612 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) + promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } diff --git a/relay/websocket.go b/relay/websocket.go index 5bd1e81..089805d 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod quota := 0 if !usePrice { quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) quota = int(math.Round(float64(quota) * ratio)) if ratio != 0 && quota <= 0 { @@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod //} } -func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} +//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { +// var promptTokens int +// var err error +// switch info.RelayMode { +// default: +// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) +// } +// info.PromptTokens = promptTokens +// return promptTokens, err +//} //func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { // var err error diff --git a/service/audio.go b/service/audio.go new file mode 100644 index 0000000..d558e96 --- /dev/null +++ b/service/audio.go @@ -0,0 +1,31 @@ +package service + +import ( + "encoding/base64" + "fmt" +) + +func parseAudio(audioBase64 string, format string) (duration float64, err error) { + audioData, err := base64.StdEncoding.DecodeString(audioBase64) + if err != nil { + return 0, fmt.Errorf("base64 decode error: %v", err) + } + + var samplesCount int + var sampleRate int + + switch format { + case "pcm16": + samplesCount = len(audioData) / 2 // 16位 = 2字节每样本 + sampleRate = 24000 // 24kHz + case "g711_ulaw", "g711_alaw": + samplesCount = len(audioData) // 8位 = 1字节每样本 + sampleRate = 8000 // 8kHz + default: + samplesCount = len(audioData) // 8位 = 1字节每样本 + sampleRate = 8000 // 8kHz + } + + duration = float64(samplesCount) / float64(sampleRate) + return duration, nil +} diff --git a/service/relay.go b/service/relay.go index 4b5ed36..6ffed1e 100644 --- a/service/relay.go +++ b/service/relay.go @@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error { common.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } - common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) + //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) return ws.WriteMessage(1, []byte(str)) } @@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { common.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } - common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) + //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) return ws.WriteMessage(1, jsonData) } diff --git a/service/token_counter.go b/service/token_counter.go index e169a25..63eb712 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "strings" "unicode/utf8" ) @@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, return tkm, nil } -func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) { - tkm := 0 - ratio := 1 - if request.Session != nil { - msgTokens, err := CountTokenText(request.Session.Instructions, model) - if err != nil { - return 0, err +func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) { + audioToken := 0 + textToken := 0 + switch request.Type { + case dto.RealtimeEventTypeSessionUpdate: + if request.Session != nil { + msgTokens, err := CountTextToken(request.Session.Instructions, model) + if err != nil { + return 0, 0, err + } + textToken += msgTokens } - ratio = len(request.Session.Modalities) - tkm += msgTokens - if request.Session.Tools != nil { - toolsData, _ := json.Marshal(request.Session.Tools) - var openaiTools []dto.OpenAITools - err := json.Unmarshal(toolsData, &openaiTools) - if err != nil { - return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) - } - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) + case dto.RealtimeEventResponseAudioDelta: + // count audio token + atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat) + if err != nil { + return 0, 0, fmt.Errorf("error counting audio token: %v", err) + } + audioToken += atk + case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: + // count text token + tkm, err := CountTextToken(request.Delta, model) + if err != nil { + return 0, 0, fmt.Errorf("error counting text token: %v", err) + } + textToken += tkm + case dto.RealtimeEventInputAudioBufferAppend: + // count audio token + atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat) + if err != nil { + return 0, 0, fmt.Errorf("error counting audio token: %v", err) + } + audioToken += atk + case dto.RealtimeEventTypeResponseDone: + // count tools token + if !info.IsFirstRequest { + if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { + for _, tool := range info.RealtimeTools { + toolTokens, err := CountTokenInput(tool, model) + if err != nil { + return 0, 0, err + } + textToken += 8 + textToken += toolTokens } } - toolTokens, err := CountTokenInput(countStr, model) - if err != nil { - return 0, err - } - tkm += 8 - tkm += toolTokens } } - tkm *= ratio - return tkm, nil + return textToken, audioToken, nil } func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { @@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, func CountTokenInput(input any, model string) (int, error) { switch v := input.(type) { case string: - return CountTokenText(v, model) + return CountTextToken(v, model) case []string: text := "" for _, s := range v { text += s } - return CountTokenText(text, model) + return CountTextToken(text, model) } return CountTokenInput(fmt.Sprintf("%v", input), model) } @@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountAudioToken(text string, model string) (int, error) { +func CountTTSToken(text string, model string) (int, error) { if strings.HasPrefix(model, "tts") { return utf8.RuneCountInString(text), nil } else { - return CountTokenText(text, model) + return CountTextToken(text, model) } } -// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTokenText(text string, model string) (int, error) { +func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) { + if audioBase64 == "" { + return 0, nil + } + duration, err := parseAudio(audioBase64, audioFormat) + if err != nil { + return 0, err + } + return int(duration / 60 * 100 / 0.06), nil +} + +func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) { + if audioBase64 == "" { + return 0, nil + } + duration, err := parseAudio(audioBase64, audioFormat) + if err != nil { + return 0, err + } + return int(duration / 60 * 200 / 0.24), nil +} + +//func CountAudioToken(sec float64, audioType string) { +// if audioType == "input" { +// +// } +//} + +// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 +func CountTextToken(text string, model string) (int, error) { var err error tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text), err diff --git a/service/usage_helpr.go b/service/usage_helpr.go index d2fa102..c52e1e1 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -19,7 +19,7 @@ import ( func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTokenText(responseText, modeName) + ctkm, err := CountTextToken(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err