From b83e4002975815745b5ff9bcf77db00135e4129b Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sun, 26 Jan 2025 12:17:31 +0000 Subject: [PATCH] fix: change GetAudioTokens to return float64 and update related functions --- common/helper/audio.go | 5 ++--- relay/adaptor/openai/adaptor.go | 1 + relay/adaptor/openai/token.go | 8 ++++++-- relay/billing/ratio/model.go | 5 ++--- relay/controller/audio.go | 5 +++-- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/common/helper/audio.go b/common/helper/audio.go index e5689904..e31afc44 100644 --- a/common/helper/audio.go +++ b/common/helper/audio.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "math" "os" "os/exec" "strconv" @@ -33,7 +32,7 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) { } // GetAudioTokens returns the number of tokens in an audio file. -func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) { +func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) { filename, err := SaveTmpFile("audio", audio) if err != nil { return 0, errors.Wrap(err, "failed to save audio to temporary file") @@ -45,7 +44,7 @@ func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) ( return 0, errors.Wrap(err, "failed to get audio tokens") } - return int(math.Ceil(duration)) * tokensPerSecond, nil + return duration * tokensPerSecond, nil } // GetAudioDuration returns the duration of an audio file in seconds. diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index e688d5fa..b54c71ee 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/minimax" diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index 1287e44b..4cfdd2b7 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -93,7 +93,9 @@ func CountTokenMessages(ctx context.Context, tokensPerMessage = 3 tokensPerName = 1 } + tokenNum := 0 + var totalAudioTokens float64 for _, message := range messages { tokenNum += tokensPerMessage contents := message.ParseContent() @@ -117,17 +119,19 @@ func CountTokenMessages(ctx context.Context, logger.SysError("error decoding audio data: " + err.Error()) } - tokens, err := helper.GetAudioTokens(ctx, + audioTokens, err := helper.GetAudioTokens(ctx, bytes.NewReader(audioData), ratio.GetAudioPromptTokensPerSecond(actualModel)) if err != nil { logger.SysError("error counting audio tokens: " + err.Error()) } else { - tokenNum += tokens + totalAudioTokens += audioTokens } } } + tokenNum += int(math.Ceil(totalAudioTokens)) + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index d52de788..c992260b 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -3,7 +3,6 @@ package ratio import ( "encoding/json" "fmt" - "math" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -391,7 +390,7 @@ var AudioPromptTokensPerSecond = map[string]float64{ // GetAudioPromptTokensPerSecond returns the number of audio tokens per second // for the given model. -func GetAudioPromptTokensPerSecond(actualModelName string) int { +func GetAudioPromptTokensPerSecond(actualModelName string) float64 { var v float64 if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { v = tokensPerSecond @@ -399,7 +398,7 @@ func GetAudioPromptTokensPerSecond(actualModelName string) int { v = 10 } - return int(math.Ceil(v)) + return v } var CompletionRatio = map[string]float64{ diff --git a/relay/controller/audio.go b/relay/controller/audio.go index b90666d3..32cc0d38 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "math" "mime/multipart" "net/http" "strings" @@ -33,7 +34,7 @@ type commonAudioRequest struct { File *multipart.FileHeader `form:"file" binding:"required"` } -func countAudioTokens(c *gin.Context) (int, error) { +func countAudioTokens(c *gin.Context) (float64, error) { body, err := common.GetRequestBody(c) if err != nil { return 0, errors.WithStack(err) @@ -101,7 +102,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) } - preConsumedQuota = int64(float64(audioTokens) * ratio) + preConsumedQuota = int64(math.Ceil(audioTokens * ratio)) quota = preConsumedQuota default: return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)