mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 17:16:38 +08:00
feat: add audio processing helper functions and update Dockerfile
inspired by https://github.com/Laisky/one-api/pull/21
This commit is contained in:
parent
ddcd1295ff
commit
c1a0471e73
@ -33,7 +33,7 @@ RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/commo
|
|||||||
FROM debian:bullseye
|
FROM debian:bullseye
|
||||||
|
|
||||||
RUN apt-get update
|
RUN apt-get update
|
||||||
RUN apt-get install -y --no-install-recommends ca-certificates haveged tzdata \
|
RUN apt-get install -y --no-install-recommends ca-certificates haveged tzdata ffmpeg \
|
||||||
&& update-ca-certificates 2>/dev/null || true \
|
&& update-ca-certificates 2>/dev/null || true \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
40
common/helper/audio.go
Normal file
40
common/helper/audio.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||||
|
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||||
|
f, err := os.CreateTemp(os.TempDir(), filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(f, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
|
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||||
|
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||||
|
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||||
|
output, err := c.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||||
|
}
|
@ -338,6 +338,8 @@ var CompletionRatio = map[string]float64{
|
|||||||
// aws llama3
|
// aws llama3
|
||||||
"llama3-8b-8192(33)": 0.0006 / 0.0003,
|
"llama3-8b-8192(33)": 0.0006 / 0.0003,
|
||||||
"llama3-70b-8192(33)": 0.0035 / 0.00265,
|
"llama3-70b-8192(33)": 0.0035 / 0.00265,
|
||||||
|
// whisper
|
||||||
|
"whisper-1": 0, // only count input tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -7,15 +7,17 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/client"
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
@ -27,6 +29,35 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
func countAudioTokens(c *gin.Context) (int, error) {
|
||||||
|
body, err := common.GetRequestBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fp, err := os.CreateTemp("", "audio-*")
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(fp.Name())
|
||||||
|
|
||||||
|
_, err = io.Copy(fp, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration, err := helper.GetAudioDuration(c.Request.Context(), fp.Name())
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(math.Ceil(duration)) * TokensPerSecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
meta := meta.GetByContext(c)
|
meta := meta.GetByContext(c)
|
||||||
@ -64,9 +95,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
case relaymode.AudioSpeech:
|
case relaymode.AudioSpeech:
|
||||||
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
||||||
quota = preConsumedQuota
|
quota = preConsumedQuota
|
||||||
default:
|
case relaymode.AudioTranscription,
|
||||||
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
|
relaymode.AudioTranslation:
|
||||||
|
audioTokens, err := countAudioTokens(c)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
preConsumedQuota = int64(float64(audioTokens) * ratio)
|
||||||
|
quota = preConsumedQuota
|
||||||
|
default:
|
||||||
|
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
@ -140,7 +181,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
// responseFormat := c.DefaultPostForm("response_format", "json")
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -173,47 +214,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode != relaymode.AudioSpeech {
|
// https://github.com/Laisky/one-api/pull/21
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
// Commenting out the following code because Whisper's transcription
|
||||||
if err != nil {
|
// only charges for the length of the input audio, not for the output.
|
||||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
// -------------------------------------
|
||||||
}
|
// if relayMode != relaymode.AudioSpeech {
|
||||||
err = resp.Body.Close()
|
// responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
// }
|
||||||
|
// err = resp.Body.Close()
|
||||||
|
// if err != nil {
|
||||||
|
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
// }
|
||||||
|
|
||||||
var openAIErr openai.SlimTextResponse
|
// var openAIErr openai.SlimTextResponse
|
||||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
// if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
||||||
if openAIErr.Error.Message != "" {
|
// if openAIErr.Error.Message != "" {
|
||||||
return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
// return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
// var text string
|
||||||
|
// switch responseFormat {
|
||||||
|
// case "json":
|
||||||
|
// text, err = getTextFromJSON(responseBody)
|
||||||
|
// case "text":
|
||||||
|
// text, err = getTextFromText(responseBody)
|
||||||
|
// case "srt":
|
||||||
|
// text, err = getTextFromSRT(responseBody)
|
||||||
|
// case "verbose_json":
|
||||||
|
// text, err = getTextFromVerboseJSON(responseBody)
|
||||||
|
// case "vtt":
|
||||||
|
// text, err = getTextFromVTT(responseBody)
|
||||||
|
// default:
|
||||||
|
// return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
||||||
|
// }
|
||||||
|
// if err != nil {
|
||||||
|
// return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||||
|
// }
|
||||||
|
// quota = int64(openai.CountTokenText(text, audioModel))
|
||||||
|
// resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
// }
|
||||||
|
|
||||||
var text string
|
|
||||||
switch responseFormat {
|
|
||||||
case "json":
|
|
||||||
text, err = getTextFromJSON(responseBody)
|
|
||||||
case "text":
|
|
||||||
text, err = getTextFromText(responseBody)
|
|
||||||
case "srt":
|
|
||||||
text, err = getTextFromSRT(responseBody)
|
|
||||||
case "verbose_json":
|
|
||||||
text, err = getTextFromVerboseJSON(responseBody)
|
|
||||||
case "vtt":
|
|
||||||
text, err = getTextFromVTT(responseBody)
|
|
||||||
default:
|
|
||||||
return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
quota = int64(openai.CountTokenText(text, audioModel))
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return RelayErrorHandler(resp)
|
return RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
succeed = true
|
succeed = true
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
|
Loading…
Reference in New Issue
Block a user