mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 17:16:38 +08:00
fix: change GetAudioTokens to return float64 and update related functions
This commit is contained in:
parent
bcba9bf3a1
commit
b83e400297
@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -33,7 +32,7 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAudioTokens returns the number of tokens in an audio file.
|
// GetAudioTokens returns the number of tokens in an audio file.
|
||||||
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) {
|
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
|
||||||
filename, err := SaveTmpFile("audio", audio)
|
filename, err := SaveTmpFile("audio", audio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to save audio to temporary file")
|
return 0, errors.Wrap(err, "failed to save audio to temporary file")
|
||||||
@ -45,7 +44,7 @@ func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (
|
|||||||
return 0, errors.Wrap(err, "failed to get audio tokens")
|
return 0, errors.Wrap(err, "failed to get audio tokens")
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(math.Ceil(duration)) * tokensPerSecond, nil
|
return duration * tokensPerSecond, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
|
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
|
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
|
||||||
|
@ -93,7 +93,9 @@ func CountTokenMessages(ctx context.Context,
|
|||||||
tokensPerMessage = 3
|
tokensPerMessage = 3
|
||||||
tokensPerName = 1
|
tokensPerName = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenNum := 0
|
tokenNum := 0
|
||||||
|
var totalAudioTokens float64
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tokenNum += tokensPerMessage
|
tokenNum += tokensPerMessage
|
||||||
contents := message.ParseContent()
|
contents := message.ParseContent()
|
||||||
@ -117,17 +119,19 @@ func CountTokenMessages(ctx context.Context,
|
|||||||
logger.SysError("error decoding audio data: " + err.Error())
|
logger.SysError("error decoding audio data: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, err := helper.GetAudioTokens(ctx,
|
audioTokens, err := helper.GetAudioTokens(ctx,
|
||||||
bytes.NewReader(audioData),
|
bytes.NewReader(audioData),
|
||||||
ratio.GetAudioPromptTokensPerSecond(actualModel))
|
ratio.GetAudioPromptTokensPerSecond(actualModel))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error counting audio tokens: " + err.Error())
|
logger.SysError("error counting audio tokens: " + err.Error())
|
||||||
} else {
|
} else {
|
||||||
tokenNum += tokens
|
totalAudioTokens += audioTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenNum += int(math.Ceil(totalAudioTokens))
|
||||||
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if message.Name != nil {
|
if message.Name != nil {
|
||||||
tokenNum += tokensPerName
|
tokenNum += tokensPerName
|
||||||
|
@ -3,7 +3,6 @@ package ratio
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
@ -391,7 +390,7 @@ var AudioPromptTokensPerSecond = map[string]float64{
|
|||||||
|
|
||||||
// GetAudioPromptTokensPerSecond returns the number of audio tokens per second
|
// GetAudioPromptTokensPerSecond returns the number of audio tokens per second
|
||||||
// for the given model.
|
// for the given model.
|
||||||
func GetAudioPromptTokensPerSecond(actualModelName string) int {
|
func GetAudioPromptTokensPerSecond(actualModelName string) float64 {
|
||||||
var v float64
|
var v float64
|
||||||
if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok {
|
if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok {
|
||||||
v = tokensPerSecond
|
v = tokensPerSecond
|
||||||
@ -399,7 +398,7 @@ func GetAudioPromptTokensPerSecond(actualModelName string) int {
|
|||||||
v = 10
|
v = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(math.Ceil(v))
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
var CompletionRatio = map[string]float64{
|
var CompletionRatio = map[string]float64{
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -33,7 +34,7 @@ type commonAudioRequest struct {
|
|||||||
File *multipart.FileHeader `form:"file" binding:"required"`
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func countAudioTokens(c *gin.Context) (int, error) {
|
func countAudioTokens(c *gin.Context) (float64, error) {
|
||||||
body, err := common.GetRequestBody(c)
|
body, err := common.GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
@ -101,7 +102,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
preConsumedQuota = int64(float64(audioTokens) * ratio)
|
preConsumedQuota = int64(math.Ceil(audioTokens * ratio))
|
||||||
quota = preConsumedQuota
|
quota = preConsumedQuota
|
||||||
default:
|
default:
|
||||||
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)
|
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)
|
||||||
|
Loading…
Reference in New Issue
Block a user