From c6c8053ccce7d23814126be00d76d70e60047a21 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 8 Jan 2025 05:00:33 +0000 Subject: [PATCH 1/6] fix: audio transcription only charge for the length of audio duration --- Dockerfile | 4 +- common/helper/audio.go | 40 ++++++++++ common/helper/audio_test.go | 37 +++++++++ go.mod | 21 +++-- go.sum | 38 ++++----- relay/adaptor/openai/model.go | 24 +++++- relay/billing/ratio/model.go | 2 + relay/controller/audio.go | 146 ++++++++++++++++++++++++---------- router/web.go | 5 +- 9 files changed, 238 insertions(+), 79 deletions(-) create mode 100644 common/helper/audio.go create mode 100644 common/helper/audio_test.go diff --git a/Dockerfile b/Dockerfile index ade561e4..72fbd08d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,10 +35,10 @@ FROM alpine RUN apk update \ && apk upgrade \ - && apk add --no-cache ca-certificates tzdata \ + && apk add --no-cache ca-certificates tzdata ffmpeg \ && update-ca-certificates 2>/dev/null || true COPY --from=builder2 /build/one-api / EXPOSE 3000 WORKDIR /data -ENTRYPOINT ["/one-api"] \ No newline at end of file +ENTRYPOINT ["/one-api"] diff --git a/common/helper/audio.go b/common/helper/audio.go new file mode 100644 index 00000000..9db62f42 --- /dev/null +++ b/common/helper/audio.go @@ -0,0 +1,40 @@ +package helper + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "strconv" + + "github.com/pkg/errors" +) + +// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. +func SaveTmpFile(filename string, data io.Reader) (string, error) { + f, err := os.CreateTemp(os.TempDir(), filename) + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary file %s", filename) + } + defer f.Close() + + _, err = io.Copy(f, data) + if err != nil { + return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename) + } + + return f.Name(), nil +} + +// GetAudioDuration returns the duration of an audio file in seconds. 
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration") + } + + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/common/helper/audio_test.go b/common/helper/audio_test.go new file mode 100644 index 00000000..90f334a3 --- /dev/null +++ b/common/helper/audio_test.go @@ -0,0 +1,37 @@ +package helper + +import ( + "context" + "io" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetAudioDuration(t *testing.T) { + t.Run("should return correct duration for a valid audio file", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "test_audio*.mp3") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // download test audio file + resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a") + require.NoError(t, err) + defer resp.Body.Close() + + _, err = io.Copy(tmpFile, resp.Body) + require.NoError(t, err) + require.NoError(t, tmpFile.Close()) + + duration, err := GetAudioDuration(context.Background(), tmpFile.Name()) + require.NoError(t, err) + require.Equal(t, duration, 3.904) + }) + + t.Run("should return an error for a non-existent file", func(t *testing.T) { + _, err := GetAudioDuration(context.Background(), "non_existent_file.mp3") + require.Error(t, err) + }) +} diff --git a/go.mod b/go.mod index 2106cf0f..136546c1 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.31.0 golang.org/x/image v0.18.0 + golang.org/x/sync v0.10.0 google.golang.org/api v0.187.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 @@ -38,29 
+39,27 @@ require ( cloud.google.com/go/auth v0.6.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect - filippo.io/edwards25519 v1.1.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect github.com/aws/smithy-go v1.20.2 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/goccy/go-json v0.10.3 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.7 // indirect @@ -71,9 +70,8 @@ require ( github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/sessions v1.2.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile 
v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -82,7 +80,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect @@ -99,13 +97,12 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect - google.golang.org/grpc v1.64.1 // indirect + google.golang.org/grpc v1.64.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c98f1965..e04bad1f 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2Qx cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= -filippo.io/edwards25519 v1.1.0 
h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= @@ -29,8 +27,8 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1 github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= @@ -43,16 +41,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= -github.com/dlclark/regexp2 v1.11.0/go.mod 
h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= @@ -81,11 +78,10 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/go-sql-driver/mysql v1.8.1 
h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= -github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -134,12 +130,10 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 
v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -163,8 +157,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -282,8 +276,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= 
+google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go index 4c974de4..39e87262 100644 --- a/relay/adaptor/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -1,6 +1,10 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "mime/multipart" + + "github.com/songquanpeng/one-api/relay/model" +) type TextContent struct { Type string `json:"type,omitempty"` @@ -71,6 +75,24 @@ type TextToSpeechRequest struct { ResponseFormat string `json:"response_format"` } +type AudioTranscriptionRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Language string `form:"language"` + Prompt string `form:"prompt"` + ReponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` + TimestampGranularity []string `form:"timestamp_granularity"` +} + +type AudioTranslationRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Prompt string `form:"prompt"` + ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` +} + type UsageOrResponseText struct { *model.Usage ResponseText string diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index f83aa70c..d1720a99 100644 --- a/relay/billing/ratio/model.go +++ 
b/relay/billing/ratio/model.go @@ -337,6 +337,8 @@ var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens } var ( diff --git a/relay/controller/audio.go b/relay/controller/audio.go index e3d57b1e..bc756f65 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -5,17 +5,20 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" + "math" + "mime/multipart" "net/http" + "os" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -27,6 +30,53 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) +const ( + TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens +) + +type commonAudioRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` +} + +func countAudioTokens(c *gin.Context) (int, error) { + body, err := common.GetRequestBody(c) + if err != nil { + return 0, errors.WithStack(err) + } + + reqBody := new(commonAudioRequest) + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + if err = c.ShouldBind(reqBody); err != nil { + return 0, errors.WithStack(err) + } + + reqFp, err := reqBody.File.Open() + if err != nil { + return 0, errors.WithStack(err) + } + + tmpFp, err := os.CreateTemp("", "audio-*") + if err != nil { + return 0, errors.WithStack(err) + } + defer os.Remove(tmpFp.Name()) + + _, err = io.Copy(tmpFp, reqFp) + if err != nil { + return 0, errors.WithStack(err) + } + if err = tmpFp.Close(); err != nil { + return 0, errors.WithStack(err) 
+ } + + duration, err := helper.GetAudioDuration(c.Request.Context(), tmpFp.Name()) + if err != nil { + return 0, errors.WithStack(err) + } + + return int(math.Ceil(duration)) * TokensPerSecond, nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) @@ -63,9 +113,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota + case relaymode.AudioTranscription, + relaymode.AudioTranslation: + audioTokens, err := countAudioTokens(c) + if err != nil { + return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) + } + + preConsumedQuota = int64(float64(audioTokens) * ratio) + quota = preConsumedQuota default: - preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError) } + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -139,7 +199,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") + // responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -172,47 +232,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != relaymode.AudioSpeech { - responseBody, err := 
io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } + // https://github.com/Laisky/one-api/pull/21 + // Commenting out the following code because Whisper's transcription + // only charges for the length of the input audio, not for the output. + // ------------------------------------- + // if relayMode != relaymode.AudioSpeech { + // responseBody, err := io.ReadAll(resp.Body) + // if err != nil { + // return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + // } + // err = resp.Body.Close() + // if err != nil { + // return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + // } - var openAIErr openai.SlimTextResponse - if err = json.Unmarshal(responseBody, &openAIErr); err == nil { - if openAIErr.Error.Message != "" { - return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) - } - } + // var openAIErr openai.SlimTextResponse + // if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + // if openAIErr.Error.Message != "" { + // return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + // } + // } + + // var text string + // switch responseFormat { + // case "json": + // text, err = getTextFromJSON(responseBody) + // case "text": + // text, err = getTextFromText(responseBody) + // case "srt": + // text, err = getTextFromSRT(responseBody) + // case "verbose_json": + // text, err = getTextFromVerboseJSON(responseBody) + // case "vtt": + // text, err = getTextFromVTT(responseBody) + // default: 
+ // return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + // } + // if err != nil { + // return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + // } + // quota = int64(openai.CountTokenText(text, audioModel)) + // resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // } - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - default: - return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) - } - if err != nil { - return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) - } - quota = int64(openai.CountTokenText(text, audioModel)) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } if resp.StatusCode != http.StatusOK { return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { diff --git a/router/web.go b/router/web.go index 3c9b4643..ebfc2ae1 100644 --- a/router/web.go +++ b/router/web.go @@ -3,6 +3,9 @@ package router import ( "embed" "fmt" + "net/http" + "strings" + "github.com/gin-contrib/gzip" "github.com/gin-contrib/static" "github.com/gin-gonic/gin" @@ -10,8 +13,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" - "net/http" - "strings" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS) { From 2fc6caaae5e93f2efc22cc3c67df02aa16f74137 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Tue, 14 Jan 2025 06:38:07 +0000 Subject: [PATCH 2/6] feat: support gpt-4o-audio 
--- common/helper/audio.go | 25 +++- common/helper/audio_test.go | 18 +++ middleware/recover.go | 8 +- relay/adaptor/openai/adaptor.go | 21 +++ relay/adaptor/openai/constants.go | 1 + relay/adaptor/openai/main.go | 22 ++- relay/adaptor/openai/token.go | 84 ++++++----- relay/billing/ratio/model.go | 235 +++++++++++++++++++----------- relay/controller/audio.go | 31 +--- relay/controller/helper.go | 6 +- relay/controller/text.go | 4 +- relay/model/general.go | 61 ++++---- relay/model/message.go | 58 +++++++- relay/model/misc.go | 25 +++- 14 files changed, 401 insertions(+), 198 deletions(-) diff --git a/common/helper/audio.go b/common/helper/audio.go index 9db62f42..e5689904 100644 --- a/common/helper/audio.go +++ b/common/helper/audio.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "math" "os" "os/exec" "strconv" @@ -13,7 +14,11 @@ import ( // SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. func SaveTmpFile(filename string, data io.Reader) (string, error) { - f, err := os.CreateTemp(os.TempDir(), filename) + if data == nil { + return "", errors.New("data is nil") + } + + f, err := os.CreateTemp("", "*-"+filename) if err != nil { return "", errors.Wrapf(err, "failed to create temporary file %s", filename) } @@ -27,6 +32,22 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) { return f.Name(), nil } +// GetAudioTokens returns the number of tokens in an audio file. +func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) { + filename, err := SaveTmpFile("audio", audio) + if err != nil { + return 0, errors.Wrap(err, "failed to save audio to temporary file") + } + defer os.Remove(filename) + + duration, err := GetAudioDuration(ctx, filename) + if err != nil { + return 0, errors.Wrap(err, "failed to get audio tokens") + } + + return int(math.Ceil(duration)) * tokensPerSecond, nil +} + // GetAudioDuration returns the duration of an audio file in seconds. 
func GetAudioDuration(ctx context.Context, filename string) (float64, error) { // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} @@ -36,5 +57,7 @@ func GetAudioDuration(ctx context.Context, filename string) (float64, error) { return 0, errors.Wrap(err, "failed to get audio duration") } + // Actually gpt-4-audio calculates tokens with 0.1s precision, + // while whisper calculates tokens with 1s precision return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) } diff --git a/common/helper/audio_test.go b/common/helper/audio_test.go index 90f334a3..15f55bbb 100644 --- a/common/helper/audio_test.go +++ b/common/helper/audio_test.go @@ -35,3 +35,21 @@ func TestGetAudioDuration(t *testing.T) { require.Error(t, err) }) } + +func TestGetAudioTokens(t *testing.T) { + t.Run("should return correct tokens for a valid audio file", func(t *testing.T) { + // download test audio file + resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a") + require.NoError(t, err) + defer resp.Body.Close() + + tokens, err := GetAudioTokens(context.Background(), resp.Body, 50) + require.NoError(t, err) + require.Equal(t, tokens, 200) + }) + + t.Run("should return an error for a non-existent file", func(t *testing.T) { + _, err := GetAudioTokens(context.Background(), nil, 1) + require.Error(t, err) + }) +} diff --git a/middleware/recover.go b/middleware/recover.go index cfc3f827..a690c77b 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -14,11 +14,11 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { ctx := c.Request.Context() - logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) - logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) - logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + logger.Errorf(ctx, "panic detected: %v", err) + logger.Errorf(ctx, "stacktrace from panic: %s", 
string(debug.Stack())) + logger.Errorf(ctx, "request: %s %s", c.Request.Method, c.Request.URL.Path) body, _ := common.GetRequestBody(c) - logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) + logger.Errorf(ctx, "request body: %s", string(body)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 6946e402..21966262 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -82,6 +82,27 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } request.StreamOptions.IncludeUsage = true } + + // o1/o1-mini/o1-preview do not support system prompt and max_tokens + if strings.HasPrefix(request.Model, "o1") { + request.MaxTokens = 0 + request.Messages = func(raw []model.Message) (filtered []model.Message) { + for i := range raw { + if raw[i].Role != "system" { + filtered = append(filtered, raw[i]) + } + } + + return + }(request.Messages) + } + + if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") { + // TODO: Since it is not clear how to implement billing in stream mode, + // it is temporarily not supported + return nil, errors.New("stream mode is not supported for gpt-4o-audio") + } + return request, nil } diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go index 8a643bc6..2c34284f 100644 --- a/relay/adaptor/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -12,6 +12,7 @@ var ModelList = []string{ "gpt-4o-2024-11-20", "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", + "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", 
"text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 97080738..095a6adb 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -5,15 +5,16 @@ import ( "bytes" "encoding/json" "io" + "math" "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" ) @@ -96,6 +97,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E return nil, responseText, usage } +// Handler handles the non-stream response from OpenAI API func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { var textResponse SlimTextResponse responseBody, err := io.ReadAll(resp.Body) @@ -146,6 +148,22 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } + } else { + // Convert the more expensive audio tokens to uniformly priced text tokens + textResponse.Usage.PromptTokens = textResponse.CompletionTokensDetails.TextTokens + + int(math.Ceil( + float64(textResponse.CompletionTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName), + )) + textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens + + int(math.Ceil( + float64(textResponse.CompletionTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName)* + ratio.GetAudioCompletionRatio(modelName), + )) + textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens + + 
textResponse.Usage.CompletionTokens } + return nil, &textResponse.Usage } diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index 7c8468b9..1287e44b 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -1,16 +1,22 @@ package openai import ( - "errors" + "bytes" + "context" + "encoding/base64" "fmt" - "github.com/pkoukk/tiktoken-go" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/image" - "github.com/songquanpeng/one-api/common/logger" - billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" - "github.com/songquanpeng/one-api/relay/model" "math" "strings" + + "github.com/pkg/errors" + "github.com/pkoukk/tiktoken-go" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/billing/ratio" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/model" ) // tokenEncoderMap won't grow after initialization @@ -70,8 +76,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func CountTokenMessages(messages []model.Message, model string) int { - tokenEncoder := getTokenEncoder(model) +func CountTokenMessages(ctx context.Context, + messages []model.Message, actualModel string) int { + tokenEncoder := getTokenEncoder(actualModel) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb // https://github.com/pkoukk/tiktoken-go/issues/6 @@ -79,7 +86,7 @@ func CountTokenMessages(messages []model.Message, model string) int { // Every message follows <|start|>{role/name}\n{content}<|end|>\n var tokensPerMessage int var tokensPerName int - if model == "gpt-3.5-turbo-0301" { + if actualModel == 
"gpt-3.5-turbo-0301" { tokensPerMessage = 4 tokensPerName = -1 // If there's a name, the role is omitted } else { @@ -89,37 +96,38 @@ func CountTokenMessages(messages []model.Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - switch v := message.Content.(type) { - case string: - tokenNum += getTokenNum(tokenEncoder, v) - case []any: - for _, it := range v { - m := it.(map[string]any) - switch m["type"] { - case "text": - if textValue, ok := m["text"]; ok { - if textString, ok := textValue.(string); ok { - tokenNum += getTokenNum(tokenEncoder, textString) - } - } - case "image_url": - imageUrl, ok := m["image_url"].(map[string]any) - if ok { - url := imageUrl["url"].(string) - detail := "" - if imageUrl["detail"] != nil { - detail = imageUrl["detail"].(string) - } - imageTokens, err := countImageTokens(url, detail, model) - if err != nil { - logger.SysError("error counting image tokens: " + err.Error()) - } else { - tokenNum += imageTokens - } - } + contents := message.ParseContent() + for _, content := range contents { + switch content.Type { + case model.ContentTypeText: + tokenNum += getTokenNum(tokenEncoder, content.Text) + case model.ContentTypeImageURL: + imageTokens, err := countImageTokens( + content.ImageURL.Url, + content.ImageURL.Detail, + actualModel) + if err != nil { + logger.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + case model.ContentTypeInputAudio: + audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data) + if err != nil { + logger.SysError("error decoding audio data: " + err.Error()) + } + + tokens, err := helper.GetAudioTokens(ctx, + bytes.NewReader(audioData), + ratio.GetAudioPromptTokensPerSecond(actualModel)) + if err != nil { + logger.SysError("error counting audio tokens: " + err.Error()) + } else { + tokenNum += tokens } } } + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum 
+= tokensPerName diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index d1720a99..14a23a51 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -3,6 +3,7 @@ package ratio import ( "encoding/json" "fmt" + "math" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -22,65 +23,71 @@ const ( // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ // https://openai.com/pricing - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens - "gpt-4o": 2.5, // $0.005 / 1K tokens - "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens - "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens - "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens - "gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens - "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens - "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens - "gpt-3.5-turbo-0301": 0.75, - "gpt-3.5-turbo-0613": 0.75, - "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens - "gpt-3.5-turbo-16k-0613": 1.5, - "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens - "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens - "o1": 7.5, // $15.00 / 1M input tokens - "o1-2024-12-17": 7.5, - "o1-preview": 7.5, // $15.00 / 1M input tokens - "o1-preview-2024-09-12": 7.5, - "o1-mini": 1.5, // $3.00 / 1M input tokens - "o1-mini-2024-09-12": 1.5, - "davinci-002": 1, // $0.002 / 1K tokens - "babbage-002": 0.2, // $0.0004 / 1K tokens - "text-ada-001": 0.2, - "text-babbage-001": 0.25, - "text-curie-001": 1, - "text-davinci-002": 10, - "text-davinci-003": 10, - 
"text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "tts-1": 7.5, // $0.015 / 1K characters - "tts-1-1106": 7.5, - "tts-1-hd": 15, // $0.030 / 1K characters - "tts-1-hd-1106": 15, - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-ada-002": 0.05, - "text-embedding-3-small": 0.01, - "text-embedding-3-large": 0.065, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image - "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4o": 2.5, // $0.005 / 1K tokens + "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens + "gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens + "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens + "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + // Audio billing will mix text and audio tokens, the unit price is different. 
+ // Here records the cost of text, the cost multiplier of audio + // relative to text is in AudioRatio + "gpt-4o-audio-preview": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-12-17": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0025 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens + "o1": 7.5, // $15.00 / 1M input tokens + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, // $15.00 / 1M input tokens + "o1-preview-2024-09-12": 7.5, + "o1-mini": 1.5, // $3.00 / 1M input tokens + "o1-mini-2024-09-12": 1.5, + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, @@ -254,7 +261,6 @@ var ModelRatio = map[string]float64{ "llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, 
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD, "mixtral-8x7b-32768": 0.24 / 1000000 * USD, - // https://platform.lingyiwanwu.com/docs#-计费单元 "yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB, @@ -333,6 +339,68 @@ var ModelRatio = map[string]float64{ "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, } +// AudioRatio represents the price ratio between audio tokens and text tokens +var AudioRatio = map[string]float64{ + "gpt-4o-audio-preview": 16, + "gpt-4o-audio-preview-2024-12-17": 16, + "gpt-4o-audio-preview-2024-10-01": 40, +} + +// GetAudioPromptRatio returns the audio prompt ratio for the given model. +func GetAudioPromptRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioRatio[actualModelName]; ok { + v = ratio + } else { + v = 16 + } + + return v +} + +// AudioCompletionRatio is the completion ratio for audio models. +var AudioCompletionRatio = map[string]float64{ + "whisper-1": 0, + "gpt-4o-audio-preview": 2, + "gpt-4o-audio-preview-2024-12-17": 2, + "gpt-4o-audio-preview-2024-10-01": 2, +} + +// GetAudioCompletionRatio returns the completion ratio for audio models. +func GetAudioCompletionRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioCompletionRatio[actualModelName]; ok { + v = ratio + } else { + v = 2 + } + + return v +} + +// AudioTokensPerSecond is the number of audio tokens per second for each model. +var AudioPromptTokensPerSecond = map[string]float64{ + // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens + "whisper-1": 1000 / 20, + // gpt-4o-audio series processes 10 tokens per second + "gpt-4o-audio-preview": 10, + "gpt-4o-audio-preview-2024-12-17": 10, + "gpt-4o-audio-preview-2024-10-01": 10, +} + +// GetAudioPromptTokensPerSecond returns the number of audio tokens per second +// for the given model. 
+func GetAudioPromptTokensPerSecond(actualModelName string) int { + var v float64 + if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { + v = tokensPerSecond + } else { + v = 10 + } + + return int(math.Ceil(v)) +} + var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, @@ -397,19 +465,21 @@ func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := ModelRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[model]; ok { - return ratio - } - if ratio, ok := ModelRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + ModelRatio, + DefaultModelRatio, + AudioRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + logger.SysError("model ratio not found: " + name) return 30 } @@ -432,18 +502,19 @@ func GetCompletionRatio(name string, channelType int) float64 { name = strings.TrimSuffix(name, "-internet") } model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := CompletionRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[model]; ok { - return ratio - } - if ratio, ok := CompletionRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + CompletionRatio, + DefaultCompletionRatio, + AudioCompletionRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + if strings.HasPrefix(name, "gpt-3.5") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // 
https://openai.com/blog/new-embedding-models-and-api-updates diff --git a/relay/controller/audio.go b/relay/controller/audio.go index bc756f65..b90666d3 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,10 +7,8 @@ import ( "encoding/json" "fmt" "io" - "math" "mime/multipart" "net/http" - "os" "strings" "github.com/gin-gonic/gin" @@ -23,6 +21,7 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" + "github.com/songquanpeng/one-api/relay/billing/ratio" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" @@ -30,10 +29,6 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) -const ( - TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens -) - type commonAudioRequest struct { File *multipart.FileHeader `form:"file" binding:"required"` } @@ -54,27 +49,13 @@ func countAudioTokens(c *gin.Context) (int, error) { if err != nil { return 0, errors.WithStack(err) } + defer reqFp.Close() - tmpFp, err := os.CreateTemp("", "audio-*") - if err != nil { - return 0, errors.WithStack(err) - } - defer os.Remove(tmpFp.Name()) + ctxMeta := meta.GetByContext(c) - _, err = io.Copy(tmpFp, reqFp) - if err != nil { - return 0, errors.WithStack(err) - } - if err = tmpFp.Close(); err != nil { - return 0, errors.WithStack(err) - } - - duration, err := helper.GetAudioDuration(c.Request.Context(), tmpFp.Name()) - if err != nil { - return 0, errors.WithStack(err) - } - - return int(math.Ceil(duration)) * TokensPerSecond, nil + return helper.GetAudioTokens(c.Request.Context(), + reqFp, + ratio.GetAudioPromptTokensPerSecond(ctxMeta.ActualModelName)) } func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 
5f5fc90c..03d79b3d 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/songquanpeng/one-api/relay/constant/role" "math" "net/http" "strings" @@ -17,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/constant/role" "github.com/songquanpeng/one-api/relay/controller/validator" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -42,10 +42,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { +func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: - return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model) case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) case relaymode.Moderations: diff --git a/relay/controller/text.go b/relay/controller/text.go index 9a47c58b..203719f6 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,11 +4,11 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/config" "io" "net/http" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor" @@ -43,7 +43,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota - 
promptTokens := getPromptTokens(textRequest, meta.Mode) + promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode) meta.PromptTokens = promptTokens preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) if bizErr != nil { diff --git a/relay/model/general.go b/relay/model/general.go index 288c07ff..5354694c 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -23,36 +23,37 @@ type StreamOptions struct { type GeneralOpenAIRequest struct { // https://platform.openai.com/docs/api-reference/chat/create - Messages []Message `json:"messages,omitempty"` - Model string `json:"model,omitempty"` - Store *bool `json:"store,omitempty"` - Metadata any `json:"metadata,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - LogitBias any `json:"logit_bias,omitempty"` - Logprobs *bool `json:"logprobs,omitempty"` - TopLogprobs *int `json:"top_logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` - N int `json:"n,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Prediction any `json:"prediction,omitempty"` - Audio *Audio `json:"audio,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - ServiceTier *string `json:"service_tier,omitempty"` - Stop any `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` - User string `json:"user,omitempty"` - FunctionCall any `json:"function_call,omitempty"` - Functions any 
`json:"functions,omitempty"` + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store *bool `json:"store,omitempty"` + Metadata any `json:"metadata,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias any `json:"logit_bias,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + // Modalities currently the model only programmatically allows modalities = [“text”, “audio”] + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` // https://platform.openai.com/docs/api-reference/embeddings/create Input any `json:"input,omitempty"` EncodingFormat string `json:"encoding_format,omitempty"` diff --git a/relay/model/message.go b/relay/model/message.go index b908f989..48ddb3ad 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,11 +1,26 @@ package model +import ( + "context" + + "github.com/songquanpeng/one-api/common/logger" +) + type 
Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` - ToolCallId string `json:"tool_call_id,omitempty"` + Role string `json:"role,omitempty"` + // Content is a string or a list of objects + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` + Audio *messageAudio `json:"audio,omitempty"` +} + +type messageAudio struct { + Id string `json:"id"` + Data string `json:"data,omitempty"` + ExpiredAt int `json:"expired_at,omitempty"` + Transcript string `json:"transcript,omitempty"` } func (m Message) IsStringContent() bool { @@ -26,6 +41,7 @@ func (m Message) StringContent() string { if !ok { continue } + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr @@ -34,6 +50,7 @@ func (m Message) StringContent() string { } return contentStr } + return "" } @@ -47,6 +64,7 @@ func (m Message) ParseContent() []MessageContent { }) return contentList } + anyList, ok := m.Content.([]any) if ok { for _, contentItem := range anyList { @@ -71,8 +89,21 @@ func (m Message) ParseContent() []MessageContent { }, }) } + case ContentTypeInputAudio: + if subObj, ok := contentMap["input_audio"].(map[string]any); ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeInputAudio, + InputAudio: &InputAudio{ + Data: subObj["data"].(string), + Format: subObj["format"].(string), + }, + }) + } + default: + logger.Warnf(context.TODO(), "unknown content type: %s", contentMap["type"]) } } + return contentList } return nil @@ -84,7 +115,18 @@ type ImageURL struct { } type MessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` + // Type should be one of the following: text/input_audio + 
Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` + InputAudio *InputAudio `json:"input_audio,omitempty"` +} + +type InputAudio struct { + // Data is the base64 encoded audio data + Data string `json:"data" binding:"required"` + // Format is the audio format, should be one of the + // following: mp3/mp4/mpeg/mpga/m4a/wav/webm/pcm16. + // When stream=true, format should be pcm16 + Format string `json:"format"` } diff --git a/relay/model/misc.go b/relay/model/misc.go index 163bc398..ff3f061d 100644 --- a/relay/model/misc.go +++ b/relay/model/misc.go @@ -1,9 +1,13 @@ package model type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"` + ServiceTier string `gorm:"-" json:"service_tier,omitempty"` + SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"` } type Error struct { @@ -17,3 +21,18 @@ type ErrorWithStatusCode struct { Error StatusCode int `json:"status_code"` } + +type usagePromptTokensDetails struct { + CachedTokens int `json:"cached_tokens"` + AudioTokens int `json:"audio_tokens"` + TextTokens int `json:"text_tokens"` + ImageTokens int `json:"image_tokens"` +} + +type usageCompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` + TextTokens int `json:"text_tokens"` +} From ca9aaaf07dec21a4537671b0a55e5d2ad57985e9 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" 
Date: Tue, 14 Jan 2025 13:37:00 +0000 Subject: [PATCH 3/6] fix: enhance token usage calculations and improve logging in OpenAI handler --- relay/adaptor/openai/main.go | 16 ++++++++++------ relay/model/misc.go | 18 +++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 095a6adb..f986ed09 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -118,8 +118,10 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st StatusCode: resp.StatusCode, }, nil } + // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. @@ -148,19 +150,21 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } - } else { - // Convert the more expensive audio tokens to uniformly priced text tokens - textResponse.Usage.PromptTokens = textResponse.CompletionTokensDetails.TextTokens + + } else if textResponse.PromptTokensDetails.AudioTokens+textResponse.CompletionTokensDetails.AudioTokens > 0 { + // Convert the more expensive audio tokens to uniformly priced text tokens. + // Note that when there are no audio tokens in prompt and completion, + // OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading. 
+ textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens + int(math.Ceil( - float64(textResponse.CompletionTokensDetails.AudioTokens)* + float64(textResponse.PromptTokensDetails.AudioTokens)* ratio.GetAudioPromptRatio(modelName), )) textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens + int(math.Ceil( float64(textResponse.CompletionTokensDetails.AudioTokens)* - ratio.GetAudioPromptRatio(modelName)* - ratio.GetAudioCompletionRatio(modelName), + ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName), )) + textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens } diff --git a/relay/model/misc.go b/relay/model/misc.go index ff3f061d..62c3fe6f 100644 --- a/relay/model/misc.go +++ b/relay/model/misc.go @@ -1,10 +1,12 @@ package model type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + // PromptTokensDetails may be empty for some models + PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"` + // CompletionTokensDetails may be empty for some models CompletionTokensDetails usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"` ServiceTier string `gorm:"-" json:"service_tier,omitempty"` SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"` @@ -25,8 +27,9 @@ type ErrorWithStatusCode struct { type usagePromptTokensDetails struct { CachedTokens int `json:"cached_tokens"` AudioTokens int `json:"audio_tokens"` - TextTokens int `json:"text_tokens"` - ImageTokens int `json:"image_tokens"` + // TextTokens could be zero for pure text chats + 
TextTokens int `json:"text_tokens"` + ImageTokens int `json:"image_tokens"` } type usageCompletionTokensDetails struct { @@ -34,5 +37,6 @@ type usageCompletionTokensDetails struct { AudioTokens int `json:"audio_tokens"` AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` RejectedPredictionTokens int `json:"rejected_prediction_tokens"` - TextTokens int `json:"text_tokens"` + // TextTokens could be zero for pure text chats + TextTokens int `json:"text_tokens"` } From 010bc72304f44b4e5edde5eff1a4857f72347858 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sun, 26 Jan 2025 08:02:55 +0000 Subject: [PATCH 4/6] fix: whisper model billing - Refactor model name handling across multiple controllers to improve clarity and maintainability. - Enhance error logging and handling for better debugging and request processing robustness. - Update pricing models in accordance with new calculations, ensuring accuracy in the billing logic. --- common/ctxkey/key.go | 1 + relay/adaptor/proxy/adaptor.go | 1 - relay/billing/ratio/model.go | 7 ++++--- relay/controller/helper.go | 11 ----------- relay/controller/image.go | 11 ++++++----- relay/controller/text.go | 19 +++++++++++-------- relay/meta/relay_meta.go | 26 +++++++++++++++++++++++++- 7 files changed, 47 insertions(+), 29 deletions(-) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a5..75c6da51 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( AvailableModels = "available_models" KeyRequestBody = "key_request_body" SystemPrompt = "system_prompt" + Meta = "meta" ) diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go index 670c7628..06fddee0 100644 --- a/relay/adaptor/proxy/adaptor.go +++ b/relay/adaptor/proxy/adaptor.go @@ -60,7 +60,6 @@ func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId) return meta.BaseURL + 
strings.TrimPrefix(meta.RequestURLPath, prefix), nil - } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 14a23a51..d52de788 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -71,7 +71,7 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, "tts-1": 7.5, // $0.015 / 1K characters "tts-1-1106": 7.5, "tts-1-hd": 15, // $0.030 / 1K characters @@ -380,8 +380,9 @@ func GetAudioCompletionRatio(actualModelName string) float64 { // AudioTokensPerSecond is the number of audio tokens per second for each model. var AudioPromptTokensPerSecond = map[string]float64{ - // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens - "whisper-1": 1000 / 20, + // whisper 的 API 价格是 $0.0001/sec。one-api 的历史倍率为 15,对应 $0.03/kilo_tokens。 + // 那么换算后可得,每秒的 tokens 应该为 0.0001/0.03*1000 = 3.3333 + "whisper-1": 0.0001 / 0.03 * 1000, // gpt-4o-audio series processes 10 tokens per second "gpt-4o-audio-preview": 10, "gpt-4o-audio-preview-2024-12-17": 10, diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 03d79b3d..d8937224 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -129,17 +129,6 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M model.UpdateChannelUsedQuota(meta.ChannelId, quota) } -func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { - if mapping == nil { - return modelName, false - } - mappedModelName := mapping[modelName] - if mappedModelName != "" { - return mappedModelName, true - } - return modelName, false -} - func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { if resp == nil { if meta.ChannelType == 
channeltype.AwsClaude { diff --git a/relay/controller/image.go b/relay/controller/image.go index 1b69d97d..581859f1 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -18,7 +18,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" + metalib "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" ) @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) @@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -114,7 +114,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // map model name var isModelMapped bool meta.OriginModelName = imageRequest.Model - imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) + imageRequest.Model = meta.ActualModelName + isModelMapped = meta.OriginModelName != meta.ActualModelName meta.ActualModelName = imageRequest.Model // model validation @@ -130,7 +131,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus imageModel := 
imageRequest.Model // Convert the original image model - imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) + imageRequest.Model = metalib.GetMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader diff --git a/relay/controller/text.go b/relay/controller/text.go index 203719f6..69a51386 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -17,13 +17,13 @@ import ( "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/model" + metalib "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { +func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { @@ -34,7 +34,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name meta.OriginModelName = textRequest.Model - textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model = meta.ActualModelName meta.ActualModelName = textRequest.Model // set system prompt if not empty systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) @@ -86,9 +86,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return nil } -func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { - if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == 
meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { - // no need to convert request for openai +func getRequestBody(c *gin.Context, meta *metalib.Meta, textRequest *relaymodel.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { + if !config.EnforceIncludeUsage && + meta.APIType == apitype.OpenAI && + meta.OriginModelName == meta.ActualModelName && + meta.ChannelType != channeltype.OpenAI && // openai also needs to convert the request + meta.ChannelType != channeltype.Baichuan { return c.Request.Body, nil } diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index bcbe1045..02b19504 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,12 +1,13 @@ package meta import ( + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" - "strings" ) type Meta struct { @@ -33,6 +34,20 @@ type Meta struct { SystemPrompt string } +// GetMappedModelName returns the mapped model name, falling back to the original model name when no mapping applies +func GetMappedModelName(modelName string, mapping map[string]string) string { + if mapping == nil { + return modelName + } + + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName + } + + return modelName +} + func GetByContext(c *gin.Context) *Meta { meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), @@ -44,6 +59,7 @@ func GetByContext(c *gin.Context) *Meta { Group: c.GetString(ctxkey.Group), ModelMapping: c.GetStringMapString(ctxkey.ModelMapping), OriginModelName: c.GetString(ctxkey.RequestModel), + ActualModelName: c.GetString(ctxkey.RequestModel), BaseURL: c.GetString(ctxkey.BaseURL), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), @@ -57,5 +73,13 @@
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } meta.APIType = channeltype.ToAPIType(meta.ChannelType) + + meta.ActualModelName = GetMappedModelName(meta.OriginModelName, meta.ModelMapping) + + Set2Context(c, &meta) return &meta } + +func Set2Context(c *gin.Context, meta *Meta) { + c.Set(ctxkey.Meta, meta) +} From bcba9bf3a19754f17c34565dd530abae88e9a8db Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sun, 26 Jan 2025 08:21:51 +0000 Subject: [PATCH 5/6] feat: only allow gpt-audio stream mode when EnforceIncludeUsage is true --- common/config/config.go | 1 + relay/adaptor/openai/adaptor.go | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/common/config/config.go b/common/config/config.go index 2eb894ef..a2a74139 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -161,4 +161,5 @@ var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) +// EnforceIncludeUsage is used to determine whether to include usage in the response var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 21966262..e688d5fa 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -97,10 +97,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G }(request.Messages) } - if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") { + if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") && !config.EnforceIncludeUsage { // TODO: Since it is not clear how to implement billing in stream mode, // it is temporarily not supported - return nil, errors.New("stream mode is not supported for gpt-4o-audio") + return nil, errors.New("set ENFORCE_INCLUDE_USAGE=true to enable stream mode for gpt-4o-audio") } return request, nil From 
b83e4002975815745b5ff9bcf77db00135e4129b Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sun, 26 Jan 2025 12:17:31 +0000 Subject: [PATCH 6/6] fix: change GetAudioTokens to return float64 and update related functions --- common/helper/audio.go | 5 ++--- relay/adaptor/openai/adaptor.go | 1 + relay/adaptor/openai/token.go | 8 ++++++-- relay/billing/ratio/model.go | 5 ++--- relay/controller/audio.go | 5 +++-- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/common/helper/audio.go b/common/helper/audio.go index e5689904..e31afc44 100644 --- a/common/helper/audio.go +++ b/common/helper/audio.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "math" "os" "os/exec" "strconv" @@ -33,7 +32,7 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) { } // GetAudioTokens returns the number of tokens in an audio file. -func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) { +func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) { filename, err := SaveTmpFile("audio", audio) if err != nil { return 0, errors.Wrap(err, "failed to save audio to temporary file") @@ -45,7 +44,7 @@ func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) ( return 0, errors.Wrap(err, "failed to get audio tokens") } - return int(math.Ceil(duration)) * tokensPerSecond, nil + return duration * tokensPerSecond, nil } // GetAudioDuration returns the duration of an audio file in seconds. 
diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index e688d5fa..b54c71ee 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/minimax" diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index 1287e44b..4cfdd2b7 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -93,7 +93,9 @@ func CountTokenMessages(ctx context.Context, tokensPerMessage = 3 tokensPerName = 1 } + tokenNum := 0 + var totalAudioTokens float64 for _, message := range messages { tokenNum += tokensPerMessage contents := message.ParseContent() @@ -117,17 +119,19 @@ func CountTokenMessages(ctx context.Context, logger.SysError("error decoding audio data: " + err.Error()) } - tokens, err := helper.GetAudioTokens(ctx, + audioTokens, err := helper.GetAudioTokens(ctx, bytes.NewReader(audioData), ratio.GetAudioPromptTokensPerSecond(actualModel)) if err != nil { logger.SysError("error counting audio tokens: " + err.Error()) } else { - tokenNum += tokens + totalAudioTokens += audioTokens } } } + tokenNum += int(math.Ceil(totalAudioTokens)) + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index d52de788..c992260b 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -3,7 +3,6 @@ package ratio import ( "encoding/json" "fmt" - "math" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -391,7 +390,7 @@ var AudioPromptTokensPerSecond = map[string]float64{ // GetAudioPromptTokensPerSecond returns the number of audio tokens per second // for the given model. 
-func GetAudioPromptTokensPerSecond(actualModelName string) int { +func GetAudioPromptTokensPerSecond(actualModelName string) float64 { var v float64 if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { v = tokensPerSecond @@ -399,7 +398,7 @@ func GetAudioPromptTokensPerSecond(actualModelName string) int { v = 10 } - return int(math.Ceil(v)) + return v } var CompletionRatio = map[string]float64{ diff --git a/relay/controller/audio.go b/relay/controller/audio.go index b90666d3..32cc0d38 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "math" "mime/multipart" "net/http" "strings" @@ -33,7 +34,7 @@ type commonAudioRequest struct { File *multipart.FileHeader `form:"file" binding:"required"` } -func countAudioTokens(c *gin.Context) (int, error) { +func countAudioTokens(c *gin.Context) (float64, error) { body, err := common.GetRequestBody(c) if err != nil { return 0, errors.WithStack(err) @@ -101,7 +102,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) } - preConsumedQuota = int64(float64(audioTokens) * ratio) + preConsumedQuota = int64(math.Ceil(audioTokens * ratio)) quota = preConsumedQuota default: return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)