feat: support gpt-4o-audio

This commit is contained in:
Laisky.Cai
2025-01-14 06:38:07 +00:00
parent c6c8053ccc
commit 2fc6caaae5
14 changed files with 401 additions and 198 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"io"
"math"
"os"
"os/exec"
"strconv"
@@ -13,7 +14,11 @@ import (
// SaveTmpFile saves data to a temporary file. The filename would be appended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
f, err := os.CreateTemp(os.TempDir(), filename)
if data == nil {
return "", errors.New("data is nil")
}
f, err := os.CreateTemp("", "*-"+filename)
if err != nil {
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
}
@@ -27,6 +32,22 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
return f.Name(), nil
}
// GetAudioTokens estimates the token cost of an audio stream.
//
// The stream is first spooled to a temporary file via SaveTmpFile, and that
// file is removed again when the function returns. Its duration in seconds is
// obtained from GetAudioDuration, rounded up to the next whole second, and
// multiplied by tokensPerSecond.
//
// An error is returned when the audio cannot be written to disk or when its
// duration cannot be determined; the token count is 0 in both cases.
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) {
	tmpPath, saveErr := SaveTmpFile("audio", audio)
	if saveErr != nil {
		return 0, errors.Wrap(saveErr, "failed to save audio to temporary file")
	}
	// Best-effort cleanup: the temporary copy is only needed for probing.
	defer os.Remove(tmpPath)

	seconds, probeErr := GetAudioDuration(ctx, tmpPath)
	if probeErr != nil {
		return 0, errors.Wrap(probeErr, "failed to get audio tokens")
	}

	// Partial seconds are billed as a full second.
	wholeSeconds := int(math.Ceil(seconds))
	return wholeSeconds * tokensPerSecond, nil
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
@@ -36,5 +57,7 @@ func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
return 0, errors.Wrap(err, "failed to get audio duration")
}
// Actually gpt-4-audio calculates tokens with 0.1s precision,
// while whisper calculates tokens with 1s precision
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
}

View File

@@ -35,3 +35,21 @@ func TestGetAudioDuration(t *testing.T) {
require.Error(t, err)
})
}
// TestGetAudioTokens checks GetAudioTokens against a remote audio sample and
// against a nil reader. The first subtest requires network access.
func TestGetAudioTokens(t *testing.T) {
	t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
		// download test audio file
		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
		require.NoError(t, err)
		defer resp.Body.Close()
		// Fail here on a non-200 response instead of handing an error page
		// to the audio probe and getting an unrelated failure later.
		require.Equal(t, http.StatusOK, resp.StatusCode)

		tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
		require.NoError(t, err)
		// testify expects (expected, actual); at 50 tokens per second the
		// expected 200 corresponds to a duration that rounds up to 4 seconds.
		require.Equal(t, 200, tokens)
	})

	t.Run("should return an error for a non-existent file", func(t *testing.T) {
		// Despite the subtest name, the input here is a nil reader rather
		// than a missing file.
		_, err := GetAudioTokens(context.Background(), nil, 1)
		require.Error(t, err)
	})
}