fix: change GetAudioTokens to return float64 and update related functions

This commit is contained in:
Laisky.Cai 2025-01-26 12:17:31 +00:00
parent bcba9bf3a1
commit b83e400297
5 changed files with 14 additions and 10 deletions

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"io" "io"
"math"
"os" "os"
"os/exec" "os/exec"
"strconv" "strconv"
@ -33,7 +32,7 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
} }
// GetAudioTokens returns the number of tokens in an audio file. // GetAudioTokens returns the number of tokens in an audio file.
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) { func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
filename, err := SaveTmpFile("audio", audio) filename, err := SaveTmpFile("audio", audio)
if err != nil { if err != nil {
return 0, errors.Wrap(err, "failed to save audio to temporary file") return 0, errors.Wrap(err, "failed to save audio to temporary file")
@ -45,7 +44,7 @@ func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (
return 0, errors.Wrap(err, "failed to get audio tokens") return 0, errors.Wrap(err, "failed to get audio tokens")
} }
return int(math.Ceil(duration)) * tokensPerSecond, nil return duration * tokensPerSecond, nil
} }
// GetAudioDuration returns the duration of an audio file in seconds. // GetAudioDuration returns the duration of an audio file in seconds.

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/minimax"

View File

@ -93,7 +93,9 @@ func CountTokenMessages(ctx context.Context,
tokensPerMessage = 3 tokensPerMessage = 3
tokensPerName = 1 tokensPerName = 1
} }
tokenNum := 0 tokenNum := 0
var totalAudioTokens float64
for _, message := range messages { for _, message := range messages {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
contents := message.ParseContent() contents := message.ParseContent()
@ -117,17 +119,19 @@ func CountTokenMessages(ctx context.Context,
logger.SysError("error decoding audio data: " + err.Error()) logger.SysError("error decoding audio data: " + err.Error())
} }
tokens, err := helper.GetAudioTokens(ctx, audioTokens, err := helper.GetAudioTokens(ctx,
bytes.NewReader(audioData), bytes.NewReader(audioData),
ratio.GetAudioPromptTokensPerSecond(actualModel)) ratio.GetAudioPromptTokensPerSecond(actualModel))
if err != nil { if err != nil {
logger.SysError("error counting audio tokens: " + err.Error()) logger.SysError("error counting audio tokens: " + err.Error())
} else { } else {
tokenNum += tokens totalAudioTokens += audioTokens
} }
} }
} }
tokenNum += int(math.Ceil(totalAudioTokens))
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil { if message.Name != nil {
tokenNum += tokensPerName tokenNum += tokensPerName

View File

@ -3,7 +3,6 @@ package ratio
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"strings" "strings"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -391,7 +390,7 @@ var AudioPromptTokensPerSecond = map[string]float64{
// GetAudioPromptTokensPerSecond returns the number of audio tokens per second // GetAudioPromptTokensPerSecond returns the number of audio tokens per second
// for the given model. // for the given model.
func GetAudioPromptTokensPerSecond(actualModelName string) int { func GetAudioPromptTokensPerSecond(actualModelName string) float64 {
var v float64 var v float64
if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok {
v = tokensPerSecond v = tokensPerSecond
@ -399,7 +398,7 @@ func GetAudioPromptTokensPerSecond(actualModelName string) int {
v = 10 v = 10
} }
return int(math.Ceil(v)) return v
} }
var CompletionRatio = map[string]float64{ var CompletionRatio = map[string]float64{

View File

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"strings" "strings"
@ -33,7 +34,7 @@ type commonAudioRequest struct {
File *multipart.FileHeader `form:"file" binding:"required"` File *multipart.FileHeader `form:"file" binding:"required"`
} }
func countAudioTokens(c *gin.Context) (int, error) { func countAudioTokens(c *gin.Context) (float64, error) {
body, err := common.GetRequestBody(c) body, err := common.GetRequestBody(c)
if err != nil { if err != nil {
return 0, errors.WithStack(err) return 0, errors.WithStack(err)
@ -101,7 +102,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
} }
preConsumedQuota = int64(float64(audioTokens) * ratio) preConsumedQuota = int64(math.Ceil(audioTokens * ratio))
quota = preConsumedQuota quota = preConsumedQuota
default: default:
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError) return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)