mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	feat: batch update with laisky's one-api
This commit is contained in:
		
							
								
								
									
										62
									
								
								common/helper/audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								common/helper/audio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,62 @@
 | 
			
		||||
package helper
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
 | 
			
		||||
func SaveTmpFile(filename string, data io.Reader) (string, error) {
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return "", errors.New("data is nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f, err := os.CreateTemp("", "*-"+filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
 | 
			
		||||
	}
 | 
			
		||||
	defer f.Close()
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(f, data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return f.Name(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAudioTokens returns the number of tokens in an audio file.
 | 
			
		||||
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
 | 
			
		||||
	filename, err := SaveTmpFile("audio", audio)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Wrap(err, "failed to save audio to temporary file")
 | 
			
		||||
	}
 | 
			
		||||
	defer os.Remove(filename)
 | 
			
		||||
 | 
			
		||||
	duration, err := GetAudioDuration(ctx, filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Wrap(err, "failed to get audio tokens")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return duration * tokensPerSecond, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAudioDuration returns the duration of an audio file in seconds.
 | 
			
		||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
 | 
			
		||||
	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
 | 
			
		||||
	c := exec.CommandContext(ctx, "/usr/bin/ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
 | 
			
		||||
	output, err := c.Output()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Wrap(err, "failed to get audio duration")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Actually gpt-4-audio calculates tokens with 0.1s precision,
 | 
			
		||||
	// while whisper calculates tokens with 1s precision
 | 
			
		||||
	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										68
									
								
								common/helper/audio_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								common/helper/audio_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
			
		||||
package helper
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestGetAudioDuration(t *testing.T) {
 | 
			
		||||
	// skip if there is no ffmpeg installed
 | 
			
		||||
	_, err := exec.LookPath("ffmpeg")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Skip("ffmpeg not installed, skipping test")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("should return correct duration for a valid audio file", func(t *testing.T) {
 | 
			
		||||
		tmpFile, err := os.CreateTemp("", "test_audio*.mp3")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		defer os.Remove(tmpFile.Name())
 | 
			
		||||
 | 
			
		||||
		// download test audio file
 | 
			
		||||
		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		_, err = io.Copy(tmpFile, resp.Body)
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		require.NoError(t, tmpFile.Close())
 | 
			
		||||
 | 
			
		||||
		duration, err := GetAudioDuration(context.Background(), tmpFile.Name())
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		require.Equal(t, duration, 3.904)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("should return an error for a non-existent file", func(t *testing.T) {
 | 
			
		||||
		_, err := GetAudioDuration(context.Background(), "non_existent_file.mp3")
 | 
			
		||||
		require.Error(t, err)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetAudioTokens(t *testing.T) {
 | 
			
		||||
	// skip if there is no ffmpeg installed
 | 
			
		||||
	_, err := exec.LookPath("ffmpeg")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Skip("ffmpeg not installed, skipping test")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
 | 
			
		||||
		// download test audio file
 | 
			
		||||
		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
		require.Equal(t, tokens, 195.2)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("should return an error for a non-existent file", func(t *testing.T) {
 | 
			
		||||
		_, err := GetAudioTokens(context.Background(), nil, 1)
 | 
			
		||||
		require.Error(t, err)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -6,13 +6,13 @@ import (
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -32,6 +32,14 @@ func OpenBrowser(url string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RespondError sends a JSON response with a success status and an error message.
 | 
			
		||||
func RespondError(c *gin.Context, err error) {
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": false,
 | 
			
		||||
		"message": err.Error(),
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetIp() (ip string) {
 | 
			
		||||
	ips, err := net.InterfaceAddrs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GetTimestamp get current timestamp in seconds
 | 
			
		||||
func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().Unix()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return aiProxyLibraryRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return ConvertRequest(c, *request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,7 @@ import (
 | 
			
		||||
type Adaptor struct{}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return ConvertRequest(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return convertedRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,7 +13,7 @@ type Adaptor interface {
 | 
			
		||||
	GetRequestURL(meta *meta.Meta) (string, error)
 | 
			
		||||
	SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
 | 
			
		||||
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
 | 
			
		||||
	ConvertImageRequest(request *model.ImageRequest) (any, error)
 | 
			
		||||
	ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
 | 
			
		||||
	DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
 
 | 
			
		||||
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,16 +39,24 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	switch meta.ChannelType {
 | 
			
		||||
	case channeltype.Azure:
 | 
			
		||||
		defaultVersion := meta.Config.APIVersion
 | 
			
		||||
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning?tabs=python#api--feature-support
 | 
			
		||||
		if strings.HasPrefix(meta.ActualModelName, "o1") ||
 | 
			
		||||
			strings.HasPrefix(meta.ActualModelName, "o3") {
 | 
			
		||||
			defaultVersion = "2024-12-01-preview"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if meta.Mode == relaymode.ImagesGenerations {
 | 
			
		||||
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 | 
			
		||||
			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
 | 
			
		||||
			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
 | 
			
		||||
			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, defaultVersion)
 | 
			
		||||
			return fullRequestURL, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
			
		||||
		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
 | 
			
		||||
		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
 | 
			
		||||
		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, defaultVersion)
 | 
			
		||||
		task := strings.TrimPrefix(requestURL, "/v1/")
 | 
			
		||||
		model_ := meta.ActualModelName
 | 
			
		||||
		model_ = strings.Replace(model_, ".", "", -1)
 | 
			
		||||
@@ -160,7 +168,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return request, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
@@ -191,6 +199,8 @@ func (a *Adaptor) DoResponse(c *gin.Context,
 | 
			
		||||
		switch meta.Mode {
 | 
			
		||||
		case relaymode.ImagesGenerations:
 | 
			
		||||
			err, _ = ImageHandler(c, resp)
 | 
			
		||||
		case relaymode.ImagesEdits:
 | 
			
		||||
			err, _ = ImagesEditsHandler(c, resp)
 | 
			
		||||
		default:
 | 
			
		||||
			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,11 +7,10 @@ var ModelList = []string{
 | 
			
		||||
	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
 | 
			
		||||
	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
 | 
			
		||||
	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 | 
			
		||||
	"gpt-4o", "gpt-4o-2024-05-13",
 | 
			
		||||
	"gpt-4o-2024-08-06",
 | 
			
		||||
	"gpt-4o-2024-11-20",
 | 
			
		||||
	"chatgpt-4o-latest",
 | 
			
		||||
	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest",
 | 
			
		||||
	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
 | 
			
		||||
	"gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17",
 | 
			
		||||
	"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
 | 
			
		||||
	"gpt-4-vision-preview",
 | 
			
		||||
	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 | 
			
		||||
	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
 | 
			
		||||
 
 | 
			
		||||
@@ -3,12 +3,30 @@ package openai
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ImagesEditsHandler just copy response body to client
 | 
			
		||||
//
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/images/createEdit
 | 
			
		||||
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var imageResponse ImageResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
@@ -13,6 +14,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/billing/ratio"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
)
 | 
			
		||||
@@ -23,144 +25,300 @@ const (
 | 
			
		||||
	dataPrefixLength = len(dataPrefix)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// StreamHandler processes streaming responses from OpenAI API
 | 
			
		||||
// It handles incremental content delivery and accumulates the final response text
 | 
			
		||||
// Returns error (if any), accumulated response text, and token usage information
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
 | 
			
		||||
	// Initialize accumulators for the response
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	reasoningText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
	var usage *model.Usage
 | 
			
		||||
 | 
			
		||||
	// Set up scanner for reading the stream line by line
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	buffer := make([]byte, 256*1024) // 256KB buffer for large messages
 | 
			
		||||
	scanner.Buffer(buffer, len(buffer))
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	// Set response headers for SSE
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	doneRendered := false
 | 
			
		||||
 | 
			
		||||
	// Process each line from the stream
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < dataPrefixLength { // ignore blank line or wrong format
 | 
			
		||||
			continue
 | 
			
		||||
		data := NormalizeDataLine(scanner.Text())
 | 
			
		||||
 | 
			
		||||
		// logger.Debugf(c.Request.Context(), "stream response: %s", data)
 | 
			
		||||
 | 
			
		||||
		// Skip lines that don't match expected format
 | 
			
		||||
		if len(data) < dataPrefixLength {
 | 
			
		||||
			continue // Ignore blank line or wrong format
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Verify line starts with expected prefix
 | 
			
		||||
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Check for stream termination
 | 
			
		||||
		if strings.HasPrefix(data[dataPrefixLength:], done) {
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			doneRendered = true
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Process based on relay mode
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case relaymode.ChatCompletions:
 | 
			
		||||
			var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
 | 
			
		||||
			// Parse the JSON response
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				render.StringData(c, data) // if error happened, pass the data to client
 | 
			
		||||
				continue                   // just ignore the error
 | 
			
		||||
				logger.Errorf(c.Request.Context(), "unmarshalling stream data %q got %+v", data, err)
 | 
			
		||||
				render.StringData(c, data) // Pass raw data to client if parsing fails
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Skip empty choices (Azure specific behavior)
 | 
			
		||||
			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
 | 
			
		||||
				// but for empty choice and no usage, we should not pass it to client, this is for azure
 | 
			
		||||
				continue // just ignore empty choice
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
 | 
			
		||||
			// Process each choice in the response
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				if choice.Delta.Reasoning != nil {
 | 
			
		||||
					reasoningText += *choice.Delta.Reasoning
 | 
			
		||||
				}
 | 
			
		||||
				if choice.Delta.ReasoningContent != nil {
 | 
			
		||||
					reasoningText += *choice.Delta.ReasoningContent
 | 
			
		||||
				// Extract reasoning content from different possible fields
 | 
			
		||||
				currentReasoningChunk := extractReasoningContent(&choice.Delta)
 | 
			
		||||
 | 
			
		||||
				// Update accumulated reasoning text
 | 
			
		||||
				if currentReasoningChunk != "" {
 | 
			
		||||
					reasoningText += currentReasoningChunk
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Set the reasoning content in the format requested by client
 | 
			
		||||
				choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
 | 
			
		||||
 | 
			
		||||
				// Accumulate response content
 | 
			
		||||
				responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Send the processed data to the client
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
 | 
			
		||||
			// Update usage information if available
 | 
			
		||||
			if streamResponse.Usage != nil {
 | 
			
		||||
				usage = streamResponse.Usage
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case relaymode.Completions:
 | 
			
		||||
			// Send the data immediately for Completions mode
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
 | 
			
		||||
			var streamResponse CompletionsStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Accumulate text from all choices
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				responseText += choice.Text
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check for scanner errors
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure stream termination is sent to client
 | 
			
		||||
	if !doneRendered {
 | 
			
		||||
		render.Done(c)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
	// Clean up resources
 | 
			
		||||
	if err := resp.Body.Close(); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Return the complete response text (reasoning + content) and usage
 | 
			
		||||
	return nil, reasoningText + responseText, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to extract reasoning content from message delta
 | 
			
		||||
func extractReasoningContent(delta *model.Message) string {
 | 
			
		||||
	content := ""
 | 
			
		||||
 | 
			
		||||
	// Extract reasoning from different possible fields
 | 
			
		||||
	if delta.Reasoning != nil {
 | 
			
		||||
		content += *delta.Reasoning
 | 
			
		||||
		delta.Reasoning = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if delta.ReasoningContent != nil {
 | 
			
		||||
		content += *delta.ReasoningContent
 | 
			
		||||
		delta.ReasoningContent = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return content
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handler processes non-streaming responses from OpenAI API
 | 
			
		||||
// Returns error (if any) and token usage information
 | 
			
		||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var textResponse SlimTextResponse
 | 
			
		||||
	// Read the entire response body
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	// Close the original response body
 | 
			
		||||
	if err = resp.Body.Close(); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	// Parse the response JSON
 | 
			
		||||
	var textResponse SlimTextResponse
 | 
			
		||||
	if err = json.Unmarshal(responseBody, &textResponse); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check for API errors
 | 
			
		||||
	if textResponse.Error.Type != "" {
 | 
			
		||||
		return &model.ErrorWithStatusCode{
 | 
			
		||||
			Error:      textResponse.Error,
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	// Reset response body
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
 | 
			
		||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
	// So the HTTPClient will be confused by the response.
 | 
			
		||||
	// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	// Process reasoning content in each choice
 | 
			
		||||
	for _, msg := range textResponse.Choices {
 | 
			
		||||
		reasoningContent := processReasoningContent(&msg)
 | 
			
		||||
 | 
			
		||||
		// Set reasoning in requested format if content exists
 | 
			
		||||
		if reasoningContent != "" {
 | 
			
		||||
			msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Reset response body for forwarding to client
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))
 | 
			
		||||
 | 
			
		||||
	// Forward all response headers (not just first value of each)
 | 
			
		||||
	for k, values := range resp.Header {
 | 
			
		||||
		for _, v := range values {
 | 
			
		||||
			c.Writer.Header().Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set response status and copy body to client
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
	// Close the reset body
 | 
			
		||||
	if err = resp.Body.Close(); err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 ||
 | 
			
		||||
		(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
 | 
			
		||||
	// Calculate token usage if not provided by API
 | 
			
		||||
	calculateTokenUsage(&textResponse, promptTokens, modelName)
 | 
			
		||||
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// processReasoningContent is a helper function to extract and process reasoning content from the message
 | 
			
		||||
func processReasoningContent(msg *TextResponseChoice) string {
 | 
			
		||||
	var reasoningContent string
 | 
			
		||||
 | 
			
		||||
	// Check different locations for reasoning content
 | 
			
		||||
	switch {
 | 
			
		||||
	case msg.Reasoning != nil:
 | 
			
		||||
		reasoningContent = *msg.Reasoning
 | 
			
		||||
		msg.Reasoning = nil
 | 
			
		||||
	case msg.ReasoningContent != nil:
 | 
			
		||||
		reasoningContent = *msg.ReasoningContent
 | 
			
		||||
		msg.ReasoningContent = nil
 | 
			
		||||
	case msg.Message.Reasoning != nil:
 | 
			
		||||
		reasoningContent = *msg.Message.Reasoning
 | 
			
		||||
		msg.Message.Reasoning = nil
 | 
			
		||||
	case msg.Message.ReasoningContent != nil:
 | 
			
		||||
		reasoningContent = *msg.Message.ReasoningContent
 | 
			
		||||
		msg.Message.ReasoningContent = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return reasoningContent
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to calculate token usage
 | 
			
		||||
func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
 | 
			
		||||
	// Calculate tokens if not provided by the API
 | 
			
		||||
	if response.Usage.TotalTokens == 0 ||
 | 
			
		||||
		(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
 | 
			
		||||
 | 
			
		||||
		completionTokens := 0
 | 
			
		||||
		for _, choice := range textResponse.Choices {
 | 
			
		||||
		for _, choice := range response.Choices {
 | 
			
		||||
			// Count content tokens
 | 
			
		||||
			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
 | 
			
		||||
 | 
			
		||||
			// Count reasoning tokens in all possible locations
 | 
			
		||||
			if choice.Message.Reasoning != nil {
 | 
			
		||||
				completionTokens += CountToken(*choice.Message.Reasoning)
 | 
			
		||||
			}
 | 
			
		||||
			if choice.Message.ReasoningContent != nil {
 | 
			
		||||
				completionTokens += CountToken(*choice.Message.ReasoningContent)
 | 
			
		||||
			}
 | 
			
		||||
			if choice.Reasoning != nil {
 | 
			
		||||
				completionTokens += CountToken(*choice.Reasoning)
 | 
			
		||||
			}
 | 
			
		||||
			if choice.ReasoningContent != nil {
 | 
			
		||||
				completionTokens += CountToken(*choice.ReasoningContent)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		textResponse.Usage = model.Usage{
 | 
			
		||||
 | 
			
		||||
		// Set usage values
 | 
			
		||||
		response.Usage = model.Usage{
 | 
			
		||||
			PromptTokens:     promptTokens,
 | 
			
		||||
			CompletionTokens: completionTokens,
 | 
			
		||||
			TotalTokens:      promptTokens + completionTokens,
 | 
			
		||||
		}
 | 
			
		||||
	} else if hasAudioTokens(response) {
 | 
			
		||||
		// Handle audio tokens conversion
 | 
			
		||||
		calculateAudioTokens(response, modelName)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to check if response has audio tokens
 | 
			
		||||
func hasAudioTokens(response *SlimTextResponse) bool {
 | 
			
		||||
	return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
 | 
			
		||||
		(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helper function to calculate audio token usage
 | 
			
		||||
func calculateAudioTokens(response *SlimTextResponse, modelName string) {
 | 
			
		||||
	// Convert audio tokens for prompt
 | 
			
		||||
	if response.PromptTokensDetails != nil {
 | 
			
		||||
		response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
 | 
			
		||||
			int(math.Ceil(
 | 
			
		||||
				float64(response.PromptTokensDetails.AudioTokens)*
 | 
			
		||||
					ratio.GetAudioPromptRatio(modelName),
 | 
			
		||||
			))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
	// Convert audio tokens for completion
 | 
			
		||||
	if response.CompletionTokensDetails != nil {
 | 
			
		||||
		response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
 | 
			
		||||
			int(math.Ceil(
 | 
			
		||||
				float64(response.CompletionTokensDetails.AudioTokens)*
 | 
			
		||||
					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
 | 
			
		||||
			))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Calculate total tokens
 | 
			
		||||
	response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,10 @@
 | 
			
		||||
package openai
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
import (
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TextContent struct {
 | 
			
		||||
	Type string `json:"type,omitempty"`
 | 
			
		||||
@@ -71,6 +75,24 @@ type TextToSpeechRequest struct {
 | 
			
		||||
	ResponseFormat string  `json:"response_format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioTranscriptionRequest struct {
 | 
			
		||||
	File                 *multipart.FileHeader `form:"file" binding:"required"`
 | 
			
		||||
	Model                string                `form:"model" binding:"required"`
 | 
			
		||||
	Language             string                `form:"language"`
 | 
			
		||||
	Prompt               string                `form:"prompt"`
 | 
			
		||||
	ReponseFormat        string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
 | 
			
		||||
	Temperature          float64               `form:"temperature"`
 | 
			
		||||
	TimestampGranularity []string              `form:"timestamp_granularity"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioTranslationRequest struct {
 | 
			
		||||
	File           *multipart.FileHeader `form:"file" binding:"required"`
 | 
			
		||||
	Model          string                `form:"model" binding:"required"`
 | 
			
		||||
	Prompt         string                `form:"prompt"`
 | 
			
		||||
	ResponseFormat string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
 | 
			
		||||
	Temperature    float64               `form:"temperature"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UsageOrResponseText struct {
 | 
			
		||||
	*model.Usage
 | 
			
		||||
	ResponseText string
 | 
			
		||||
@@ -110,12 +132,14 @@ type EmbeddingResponse struct {
 | 
			
		||||
	model.Usage `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageData represents an image in the response
 | 
			
		||||
type ImageData struct {
 | 
			
		||||
	Url           string `json:"url,omitempty"`
 | 
			
		||||
	B64Json       string `json:"b64_json,omitempty"`
 | 
			
		||||
	RevisedPrompt string `json:"revised_prompt,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageResponse represents the response structure for image generations
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int64       `json:"created"`
 | 
			
		||||
	Data    []ImageData `json:"data"`
 | 
			
		||||
 
 | 
			
		||||
@@ -1,16 +1,20 @@
 | 
			
		||||
package openai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/image"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/billing/ratio"
 | 
			
		||||
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
@@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 | 
			
		||||
	return len(tokenEncoder.Encode(text, nil, nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTokenMessages(messages []model.Message, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
// CountTokenMessages counts the number of tokens in a list of messages.
 | 
			
		||||
func CountTokenMessages(ctx context.Context,
 | 
			
		||||
	messages []model.Message, actualModel string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(actualModel)
 | 
			
		||||
	// Reference:
 | 
			
		||||
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
			
		||||
	// https://github.com/pkoukk/tiktoken-go/issues/6
 | 
			
		||||
@@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
 | 
			
		||||
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 | 
			
		||||
	var tokensPerMessage int
 | 
			
		||||
	var tokensPerName int
 | 
			
		||||
	if model == "gpt-3.5-turbo-0301" {
 | 
			
		||||
	if actualModel == "gpt-3.5-turbo-0301" {
 | 
			
		||||
		tokensPerMessage = 4
 | 
			
		||||
		tokensPerName = -1 // If there's a name, the role is omitted
 | 
			
		||||
	} else {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokenNum := 0
 | 
			
		||||
	var totalAudioTokens float64
 | 
			
		||||
	for _, message := range messages {
 | 
			
		||||
		tokenNum += tokensPerMessage
 | 
			
		||||
		switch v := message.Content.(type) {
 | 
			
		||||
		case string:
 | 
			
		||||
			tokenNum += getTokenNum(tokenEncoder, v)
 | 
			
		||||
		case []any:
 | 
			
		||||
			for _, it := range v {
 | 
			
		||||
				m := it.(map[string]any)
 | 
			
		||||
				switch m["type"] {
 | 
			
		||||
				case "text":
 | 
			
		||||
					if textValue, ok := m["text"]; ok {
 | 
			
		||||
						if textString, ok := textValue.(string); ok {
 | 
			
		||||
							tokenNum += getTokenNum(tokenEncoder, textString)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				case "image_url":
 | 
			
		||||
					imageUrl, ok := m["image_url"].(map[string]any)
 | 
			
		||||
					if ok {
 | 
			
		||||
						url := imageUrl["url"].(string)
 | 
			
		||||
						detail := ""
 | 
			
		||||
						if imageUrl["detail"] != nil {
 | 
			
		||||
							detail = imageUrl["detail"].(string)
 | 
			
		||||
						}
 | 
			
		||||
						imageTokens, err := countImageTokens(url, detail, model)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							logger.SysError("error counting image tokens: " + err.Error())
 | 
			
		||||
						} else {
 | 
			
		||||
							tokenNum += imageTokens
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
		contents := message.ParseContent()
 | 
			
		||||
		for _, content := range contents {
 | 
			
		||||
			switch content.Type {
 | 
			
		||||
			case model.ContentTypeText:
 | 
			
		||||
				if content.Text != nil {
 | 
			
		||||
					tokenNum += getTokenNum(tokenEncoder, *content.Text)
 | 
			
		||||
				}
 | 
			
		||||
			case model.ContentTypeImageURL:
 | 
			
		||||
				imageTokens, err := countImageTokens(
 | 
			
		||||
					content.ImageURL.Url,
 | 
			
		||||
					content.ImageURL.Detail,
 | 
			
		||||
					actualModel)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error counting image tokens: " + err.Error())
 | 
			
		||||
				} else {
 | 
			
		||||
					tokenNum += imageTokens
 | 
			
		||||
				}
 | 
			
		||||
			case model.ContentTypeInputAudio:
 | 
			
		||||
				audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error decoding audio data: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				audioTokens, err := helper.GetAudioTokens(ctx,
 | 
			
		||||
					bytes.NewReader(audioData),
 | 
			
		||||
					ratio.GetAudioPromptTokensPerSecond(actualModel))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error counting audio tokens: " + err.Error())
 | 
			
		||||
				} else {
 | 
			
		||||
					totalAudioTokens += audioTokens
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tokenNum += int(math.Ceil(totalAudioTokens))
 | 
			
		||||
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.Role)
 | 
			
		||||
		if message.Name != nil {
 | 
			
		||||
			tokenNum += tokensPerName
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package openai
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
@@ -21,3 +22,11 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
 | 
			
		||||
		StatusCode: statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NormalizeDataLine(data string) string {
 | 
			
		||||
	if strings.HasPrefix(data, "data:") {
 | 
			
		||||
		content := strings.TrimLeft(data[len("data:"):], " ")
 | 
			
		||||
		return "data: " + content
 | 
			
		||||
	}
 | 
			
		||||
	return data
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return ConvertRequest(*request), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return nil, errors.Errorf("not implement")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return DrawImageRequest{
 | 
			
		||||
		Input: ImageInput{
 | 
			
		||||
			Steps:           25,
 | 
			
		||||
 
 | 
			
		||||
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return convertedRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -105,7 +105,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -43,10 +43,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
 | 
			
		||||
	return textRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 | 
			
		||||
func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
		return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model)
 | 
			
		||||
	case relaymode.Completions:
 | 
			
		||||
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
 | 
			
		||||
	case relaymode.Moderations:
 | 
			
		||||
 
 | 
			
		||||
@@ -157,7 +157,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
		channeltype.Ali,
 | 
			
		||||
		channeltype.Replicate,
 | 
			
		||||
		channeltype.Baidu:
 | 
			
		||||
		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 | 
			
		||||
		finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -45,7 +45,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 | 
			
		||||
	groupRatio := billingratio.GetGroupRatio(meta.Group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	// pre-consume quota
 | 
			
		||||
	promptTokens := getPromptTokens(textRequest, meta.Mode)
 | 
			
		||||
	promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode)
 | 
			
		||||
	meta.PromptTokens = promptTokens
 | 
			
		||||
	preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
 | 
			
		||||
	if bizErr != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -13,4 +13,5 @@ const (
 | 
			
		||||
	AudioTranslation
 | 
			
		||||
	// Proxy is a special relay mode for proxying requests to custom upstream
 | 
			
		||||
	Proxy
 | 
			
		||||
	ImagesEdits
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user