feat(audio): count whisper-1 quota by audio duration

WqyJh
2024-01-17 18:16:15 +08:00
parent eed9f5fdf0
commit 89799d84c5
7 changed files with 119 additions and 59 deletions
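
Summary: transcription ("whisper-1") requests are now billed up front from the duration of the uploaded audio, rather than by counting tokens in the transcribed response text. The TokensPerSecond constant introduced below converts OpenAI's per-minute price into the token-denominated quota the relay already uses. A minimal sketch of the arithmetic (standalone illustration, not part of the commit):

    package main

    import "fmt"

    func main() {
    	// whisper-1 is billed at $0.006 per minute of audio, while quota is
    	// denominated in tokens priced at $0.002 per 1K tokens.
    	usdPerSecond := 0.006 / 60  // $0.0001 per second of audio
    	usdPerToken := 0.002 / 1000 // $0.000002 per quota token
    	fmt.Println(usdPerSecond / usdPerToken) // 50 tokens/second == 1000/20
    }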


@@ -7,17 +7,45 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
+	"math"
 	"net/http"
 	"one-api/common"
+	"one-api/common/audio"
 	"one-api/model"
 	"one-api/relay/channel/openai"
 	"one-api/relay/constant"
 	"one-api/relay/util"
+	"os"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
+
+const (
+	TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
+)
+
+func countAudioTokens(req *http.Request) (int, error) {
+	cloned := common.CloneRequest(req)
+	defer cloned.Body.Close()
+	file, header, err := cloned.FormFile("file")
+	if err != nil {
+		return 0, err
+	}
+	defer file.Close()
+	f, err := common.SaveTmpFile(header.Filename, file)
+	if err != nil {
+		return 0, err
+	}
+	defer os.Remove(f)
+	duration, err := audio.GetAudioDuration(cloned.Context(), f)
+	if err != nil {
+		return 0, err
+	}
+	return int(math.Ceil(duration)) * TokensPerSecond, nil
+}
+
 func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
 	audioModel := "whisper-1"
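
Note: countAudioTokens works on a clone of the incoming request because parsing the multipart form consumes the body, which must still be readable when it is forwarded upstream. The helpers it calls (common.CloneRequest, common.SaveTmpFile, audio.GetAudioDuration) are added elsewhere in this commit and not shown in this hunk; a minimal sketch of what a body-preserving clone could look like (an assumption, not the commit's actual implementation):

    // cloneRequestSketch buffers the body once, restores it on the original
    // request, and returns a clone holding an independent copy.
    func cloneRequestSketch(req *http.Request) (*http.Request, error) {
    	body, err := io.ReadAll(req.Body)
    	if err != nil {
    		return nil, err
    	}
    	_ = req.Body.Close()
    	req.Body = io.NopCloser(bytes.NewReader(body)) // original stays readable
    	cloned := req.Clone(req.Context())
    	cloned.Body = io.NopCloser(bytes.NewReader(body)) // clone gets its own copy
    	return cloned, nil
    }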
@@ -28,8 +56,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 	group := c.GetString("group")
 	tokenName := c.GetString("token_name")
+	var inputTokens int
 	var ttsRequest openai.TextToSpeechRequest
-	if relayMode == constant.RelayModeAudioSpeech {
+
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	var quota int
+	var preConsumedQuota int
+	switch relayMode {
+	case constant.RelayModeAudioSpeech:
 		// Read JSON
 		err := common.UnmarshalBodyReusable(c, &ttsRequest)
 		// Check if JSON is valid
@@ -41,20 +76,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		if len(ttsRequest.Input) > 4096 {
 			return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
 		}
-	}
-	modelRatio := common.GetModelRatio(audioModel)
-	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
-	var quota int
-	var preConsumedQuota int
-	switch relayMode {
-	case constant.RelayModeAudioSpeech:
-		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
-		quota = preConsumedQuota
+		inputTokens = len(ttsRequest.Input)
 	default:
-		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+		// whisper-1 audio transcription
+		audioTokens, err := countAudioTokens(c.Request)
+		if err != nil {
+			return openai.ErrorWrapper(err, "get_audio_duration_failed", http.StatusInternalServerError)
+		}
+		inputTokens = audioTokens
 	}
+
+	preConsumedQuota = int(float64(inputTokens) * ratio)
+	quota = preConsumedQuota
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
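
Note: both branches now set inputTokens before the upstream call, so pre-consumption is uniform: seconds are rounded up, converted at TokensPerSecond, and scaled by the combined ratio. A worked example (the ratio value is illustrative):

    duration := 90.4                             // seconds, as returned by GetAudioDuration
    inputTokens := int(math.Ceil(duration)) * 50 // ceil(90.4) = 91 s; 91 * 50 = 4550 tokens
    ratio := 20.0                                // assumed modelRatio * groupRatio
    preConsumedQuota := int(float64(inputTokens) * ratio) // 91,000 quota units reserved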
@@ -112,7 +144,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
 	}
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
-	responseFormat := c.DefaultPostForm("response_format", "json")
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
@@ -145,44 +176,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
-	if relayMode != constant.RelayModeAudioSpeech {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-		}
-		var openAIErr openai.SlimTextResponse
-		if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
-			if openAIErr.Error.Message != "" {
-				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
-			}
-		}
-		var text string
-		switch responseFormat {
-		case "json":
-			text, err = getTextFromJSON(responseBody)
-		case "text":
-			text, err = getTextFromText(responseBody)
-		case "srt":
-			text, err = getTextFromSRT(responseBody)
-		case "verbose_json":
-			text, err = getTextFromVerboseJSON(responseBody)
-		case "vtt":
-			text, err = getTextFromVTT(responseBody)
-		default:
-			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
-		}
-		if err != nil {
-			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
-		}
-		quota = openai.CountTokenText(text, audioModel)
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	}
 	if resp.StatusCode != http.StatusOK {
 		if preConsumedQuota > 0 {
 			// we need to roll back the pre-consumed quota
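
Note: the deleted block billed transcriptions after the fact by token-counting the transcribed text per response_format, so cost depended on transcript verbosity and could not be reserved before the upstream call. With duration-based pre-consumption, quota is already final at this point; only the non-200 rollback remains. A sketch of that rollback, assuming (as elsewhere in one-api) that a refund is posted as a negative delta:

    if resp.StatusCode != http.StatusOK && preConsumedQuota > 0 {
    	// refund the reservation by posting back a negative delta
    	if err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota); err != nil {
    		common.SysError("error rollback pre-consumed quota: " + err.Error())
    	}
    }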


@@ -136,11 +136,13 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) string
 func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 	// quotaDelta is remaining quota to be consumed
-	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
-	if err != nil {
-		common.SysError("error consuming token remain quota: " + err.Error())
+	if quotaDelta != 0 {
+		err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+		if err != nil {
+			common.SysError("error consuming token remain quota: " + err.Error())
+		}
 	}
-	err = model.CacheUpdateUserQuota(userId)
+	err := model.CacheUpdateUserQuota(userId)
 	if err != nil {
 		common.SysError("error update user quota cache: " + err.Error())
 	}
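
Note: the quotaDelta != 0 guard skips a pointless database write, and because the first err is now scoped inside that block, the later CacheUpdateUserQuota assignment becomes a fresh := declaration. With duration-based billing, an audio request's final quota equals its pre-consumed quota, so the delta settled here is now typically zero. A hypothetical call site:

    // delta between final and pre-consumed quota; zero for audio now that
    // quota == preConsumedQuota
    quotaDelta := quota - preConsumedQuota
    util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId,
    	modelRatio, groupRatio, audioModel, tokenName)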