fix: charge audio transcription only for the length of the input audio

Laisky.Cai
2025-01-08 05:00:33 +00:00
parent 3915ce9814
commit c6c8053ccc
9 changed files with 238 additions and 79 deletions
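With this change, transcription and translation requests are pre-charged from the duration of the uploaded audio rather than from the transcribed output text. A minimal, self-contained sketch of the new arithmetic follows; the 95.4-second duration and the ratio of 1.0 are made-up example values, not defaults from this repository.

package main

import (
	"fmt"
	"math"
)

// Mirrors the constant introduced in this commit:
// $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens, i.e. 50 tokens per second.
const TokensPerSecond = 1000 / 20

func main() {
	duration := 95.4 // audio length in seconds (example value)
	ratio := 1.0     // model/group price ratio (example value)

	tokens := int(math.Ceil(duration)) * TokensPerSecond // 96 * 50 = 4800
	quota := int64(float64(tokens) * ratio)              // pre-consumed quota
	fmt.Println(tokens, quota)                           // 4800 4800
}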


@@ -5,17 +5,20 @@ import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"mime/multipart"
	"net/http"
	"os"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/client"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
@@ -27,6 +30,53 @@ import (
	"github.com/songquanpeng/one-api/relay/relaymode"
)
const (
	TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
)

type commonAudioRequest struct {
	File *multipart.FileHeader `form:"file" binding:"required"`
}
// countAudioTokens estimates the billable tokens for a transcription or
// translation request from the duration of the uploaded audio file.
func countAudioTokens(c *gin.Context) (int, error) {
	body, err := common.GetRequestBody(c)
	if err != nil {
		return 0, errors.WithStack(err)
	}
	reqBody := new(commonAudioRequest)
	c.Request.Body = io.NopCloser(bytes.NewReader(body))
	if err = c.ShouldBind(reqBody); err != nil {
		return 0, errors.WithStack(err)
	}
	reqFp, err := reqBody.File.Open()
	if err != nil {
		return 0, errors.WithStack(err)
	}
	tmpFp, err := os.CreateTemp("", "audio-*")
	if err != nil {
		return 0, errors.WithStack(err)
	}
	defer os.Remove(tmpFp.Name())
	_, err = io.Copy(tmpFp, reqFp)
	if err != nil {
		return 0, errors.WithStack(err)
	}
	if err = tmpFp.Close(); err != nil {
		return 0, errors.WithStack(err)
	}
	duration, err := helper.GetAudioDuration(c.Request.Context(), tmpFp.Name())
	if err != nil {
		return 0, errors.WithStack(err)
	}
	return int(math.Ceil(duration)) * TokensPerSecond, nil
}
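helper.GetAudioDuration itself is not part of this diff; as a rough illustration only, a stand-in that probes the saved temp file with ffprobe could look like the sketch below (the real helper may work differently, and the probeDuration name is hypothetical).

package helper

import (
	"context"
	"os/exec"
	"strconv"
	"strings"
)

// probeDuration is a hypothetical stand-in for GetAudioDuration, not code from
// this commit: it shells out to ffprobe and parses the container duration in seconds.
func probeDuration(ctx context.Context, filename string) (float64, error) {
	out, err := exec.CommandContext(ctx, "ffprobe",
		"-v", "error",
		"-show_entries", "format=duration",
		"-of", "default=noprint_wrappers=1:nokey=1",
		filename).Output()
	if err != nil {
		return 0, err
	}
	return strconv.ParseFloat(strings.TrimSpace(string(out)), 64)
}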
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := meta.GetByContext(c)
@@ -63,9 +113,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
	case relaymode.AudioSpeech:
		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
		quota = preConsumedQuota
	case relaymode.AudioTranscription,
		relaymode.AudioTranslation:
		audioTokens, err := countAudioTokens(c)
		if err != nil {
			return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
		}
		preConsumedQuota = int64(float64(audioTokens) * ratio)
		quota = preConsumedQuota
	default:
		preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
		return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)
	}
	userQuota, err := model.CacheGetUserQuota(ctx, userId)
	if err != nil {
		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -139,7 +199,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
responseFormat := c.DefaultPostForm("response_format", "json")
// responseFormat := c.DefaultPostForm("response_format", "json")
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
@@ -172,47 +232,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
	// https://github.com/Laisky/one-api/pull/21
	// Commenting out the following code because Whisper's transcription
	// only charges for the length of the input audio, not for the output.
	// -------------------------------------
	// if relayMode != relaymode.AudioSpeech {
	// 	responseBody, err := io.ReadAll(resp.Body)
	// 	if err != nil {
	// 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	// 	}
	// 	err = resp.Body.Close()
	// 	if err != nil {
	// 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	// 	}
		var openAIErr openai.SlimTextResponse
		if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
			if openAIErr.Error.Message != "" {
				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
			}
		}
	// 	var openAIErr openai.SlimTextResponse
	// 	if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
	// 		if openAIErr.Error.Message != "" {
	// 			return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
	// 		}
	// 	}
	// 	var text string
	// 	switch responseFormat {
	// 	case "json":
	// 		text, err = getTextFromJSON(responseBody)
	// 	case "text":
	// 		text, err = getTextFromText(responseBody)
	// 	case "srt":
	// 		text, err = getTextFromSRT(responseBody)
	// 	case "verbose_json":
	// 		text, err = getTextFromVerboseJSON(responseBody)
	// 	case "vtt":
	// 		text, err = getTextFromVTT(responseBody)
	// 	default:
	// 		return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
	// 	}
	// 	if err != nil {
	// 		return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
	// 	}
	// 	quota = int64(openai.CountTokenText(text, audioModel))
	// 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	// }
		var text string
		switch responseFormat {
		case "json":
			text, err = getTextFromJSON(responseBody)
		case "text":
			text, err = getTextFromText(responseBody)
		case "srt":
			text, err = getTextFromSRT(responseBody)
		case "verbose_json":
			text, err = getTextFromVerboseJSON(responseBody)
		case "vtt":
			text, err = getTextFromVTT(responseBody)
		default:
			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
		}
		if err != nil {
			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
		}
		quota = int64(openai.CountTokenText(text, audioModel))
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	}
	if resp.StatusCode != http.StatusOK {
		return RelayErrorHandler(resp)
	}
	succeed = true
	quotaDelta := quota - preConsumedQuota
	defer func(ctx context.Context) {