From c1a0471e739f3d02a697ef32f9abfe4fae05ec23 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 8 Jan 2025 02:35:38 +0000 Subject: [PATCH] feat: add audio processing helper functions and update Dockerfile inspired by https://github.com/Laisky/one-api/pull/21 --- Dockerfile | 2 +- common/helper/audio.go | 40 +++++++++++ relay/billing/ratio/model.go | 2 + relay/controller/audio.go | 125 ++++++++++++++++++++++++----------- 4 files changed, 129 insertions(+), 40 deletions(-) create mode 100644 common/helper/audio.go diff --git a/Dockerfile b/Dockerfile index 4f775135..f587bdfe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/commo FROM debian:bullseye RUN apt-get update -RUN apt-get install -y --no-install-recommends ca-certificates haveged tzdata \ +RUN apt-get install -y --no-install-recommends ca-certificates haveged tzdata ffmpeg \ && update-ca-certificates 2>/dev/null || true \ && rm -rf /var/lib/apt/lists/* diff --git a/common/helper/audio.go b/common/helper/audio.go new file mode 100644 index 00000000..9db62f42 --- /dev/null +++ b/common/helper/audio.go @@ -0,0 +1,40 @@ +package helper + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "strconv" + + "github.com/pkg/errors" +) + +// SaveTmpFile saves data to a temporary file. The filename would be appended with a random string. +func SaveTmpFile(filename string, data io.Reader) (string, error) { + f, err := os.CreateTemp(os.TempDir(), filename) + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary file %s", filename) + } + defer f.Close() + + _, err = io.Copy(f, data) + if err != nil { + return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename) + } + + return f.Name(), nil +} + +// GetAudioDuration returns the duration of an audio file in seconds. 
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration") + } + + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index a8587ba4..dca3280f 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -338,6 +338,8 @@ var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens } var ( diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 541ee1a5..095aa29c 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,15 +7,17 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" + "os" "strings" "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -27,6 +29,35 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) +const ( + TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens +) + +func countAudioTokens(c *gin.Context) (int, error) { + body, err := common.GetRequestBody(c) + if err != nil { + return 0, errors.WithStack(err) + } + + fp, err := os.CreateTemp("", "audio-*") + if err != nil { + 
return 0, errors.WithStack(err) + } + defer os.Remove(fp.Name()) + + _, err = io.Copy(fp, bytes.NewReader(body)) + if err != nil { + return 0, errors.WithStack(err) + } + + duration, err := helper.GetAudioDuration(c.Request.Context(), fp.Name()) + if err != nil { + return 0, errors.WithStack(err) + } + + return int(math.Ceil(duration)) * TokensPerSecond, nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) @@ -64,9 +95,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota + case relaymode.AudioTranscription, + relaymode.AudioTranslation: + audioTokens, err := countAudioTokens(c) + if err != nil { + return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) + } + + preConsumedQuota = int64(float64(audioTokens) * ratio) + quota = preConsumedQuota default: - preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError) } + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -140,7 +181,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") + // responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -173,47 +214,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return 
openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != relaymode.AudioSpeech { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } + // https://github.com/Laisky/one-api/pull/21 + // Commenting out the following code because Whisper's transcription + // only charges for the length of the input audio, not for the output. + // ------------------------------------- + // if relayMode != relaymode.AudioSpeech { + // responseBody, err := io.ReadAll(resp.Body) + // if err != nil { + // return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + // } + // err = resp.Body.Close() + // if err != nil { + // return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + // } - var openAIErr openai.SlimTextResponse - if err = json.Unmarshal(responseBody, &openAIErr); err == nil { - if openAIErr.Error.Message != "" { - return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) - } - } + // var openAIErr openai.SlimTextResponse + // if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + // if openAIErr.Error.Message != "" { + // return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + // } + // } + + // var text string + // switch responseFormat { + // case "json": + // text, err = getTextFromJSON(responseBody) + // case "text": + // text, err = getTextFromText(responseBody) + // case "srt": + // text, err = getTextFromSRT(responseBody) + 
// case "verbose_json": + // text, err = getTextFromVerboseJSON(responseBody) + // case "vtt": + // text, err = getTextFromVTT(responseBody) + // default: + // return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + // } + // if err != nil { + // return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + // } + // quota = int64(openai.CountTokenText(text, audioModel)) + // resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // } - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - default: - return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) - } - if err != nil { - return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) - } - quota = int64(openai.CountTokenText(text, audioModel)) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } if resp.StatusCode != http.StatusOK { return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) {