mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
feat: batch update with laisky's one-api
This commit is contained in:
parent
761ee32d19
commit
b2d6aa783b
62
common/helper/audio.go
Normal file
62
common/helper/audio.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||||
|
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||||
|
if data == nil {
|
||||||
|
return "", errors.New("data is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.CreateTemp("", "*-"+filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(f, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAudioTokens returns the number of tokens in an audio file.
|
||||||
|
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
|
||||||
|
filename, err := SaveTmpFile("audio", audio)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to save audio to temporary file")
|
||||||
|
}
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
duration, err := GetAudioDuration(ctx, filename)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
return duration * tokensPerSecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
|
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||||
|
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||||
|
c := exec.CommandContext(ctx, "/usr/bin/ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||||
|
output, err := c.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Actually gpt-4-audio calculates tokens with 0.1s precision,
|
||||||
|
// while whisper calculates tokens with 1s precision
|
||||||
|
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||||
|
}
|
68
common/helper/audio_test.go
Normal file
68
common/helper/audio_test.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAudioDuration(t *testing.T) {
|
||||||
|
// skip if there is no ffmpeg installed
|
||||||
|
_, err := exec.LookPath("ffmpeg")
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("ffmpeg not installed, skipping test")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("should return correct duration for a valid audio file", func(t *testing.T) {
|
||||||
|
tmpFile, err := os.CreateTemp("", "test_audio*.mp3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(tmpFile.Name())
|
||||||
|
|
||||||
|
// download test audio file
|
||||||
|
resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(tmpFile, resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, tmpFile.Close())
|
||||||
|
|
||||||
|
duration, err := GetAudioDuration(context.Background(), tmpFile.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, duration, 3.904)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should return an error for a non-existent file", func(t *testing.T) {
|
||||||
|
_, err := GetAudioDuration(context.Background(), "non_existent_file.mp3")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAudioTokens(t *testing.T) {
|
||||||
|
// skip if there is no ffmpeg installed
|
||||||
|
_, err := exec.LookPath("ffmpeg")
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("ffmpeg not installed, skipping test")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
|
||||||
|
// download test audio file
|
||||||
|
resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tokens, 195.2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should return an error for a non-existent file", func(t *testing.T) {
|
||||||
|
_, err := GetAudioTokens(context.Background(), nil, 1)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
@ -6,13 +6,13 @@ import (
|
|||||||
"html/template"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/random"
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,6 +32,14 @@ func OpenBrowser(url string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RespondError sends a JSON response with a success status and an error message.
|
||||||
|
func RespondError(c *gin.Context, err error) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func GetIp() (ip string) {
|
func GetIp() (ip string) {
|
||||||
ips, err := net.InterfaceAddrs()
|
ips, err := net.InterfaceAddrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetTimestamp returns the current Unix timestamp in seconds.
func GetTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}
|
||||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return aiProxyLibraryRequest, nil
|
return aiProxyLibraryRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(c, *request)
|
return ConvertRequest(c, *request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
type Adaptor struct{}
|
type Adaptor struct{}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return convertedRequest, nil
|
return convertedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ type Adaptor interface {
|
|||||||
GetRequestURL(meta *meta.Meta) (string, error)
|
GetRequestURL(meta *meta.Meta) (string, error)
|
||||||
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
||||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertImageRequest(request *model.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
|
@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,16 +39,24 @@ func (a *Adaptor) Init(meta *meta.Meta) {
|
|||||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
switch meta.ChannelType {
|
switch meta.ChannelType {
|
||||||
case channeltype.Azure:
|
case channeltype.Azure:
|
||||||
|
defaultVersion := meta.Config.APIVersion
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning?tabs=python#api--feature-support
|
||||||
|
if strings.HasPrefix(meta.ActualModelName, "o1") ||
|
||||||
|
strings.HasPrefix(meta.ActualModelName, "o3") {
|
||||||
|
defaultVersion = "2024-12-01-preview"
|
||||||
|
}
|
||||||
|
|
||||||
if meta.Mode == relaymode.ImagesGenerations {
|
if meta.Mode == relaymode.ImagesGenerations {
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
|
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
|
||||||
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
|
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, defaultVersion)
|
||||||
return fullRequestURL, nil
|
return fullRequestURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
|
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, defaultVersion)
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
model_ := meta.ActualModelName
|
model_ := meta.ActualModelName
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
model_ = strings.Replace(model_, ".", "", -1)
|
||||||
@ -160,7 +168,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
@ -191,6 +199,8 @@ func (a *Adaptor) DoResponse(c *gin.Context,
|
|||||||
switch meta.Mode {
|
switch meta.Mode {
|
||||||
case relaymode.ImagesGenerations:
|
case relaymode.ImagesGenerations:
|
||||||
err, _ = ImageHandler(c, resp)
|
err, _ = ImageHandler(c, resp)
|
||||||
|
case relaymode.ImagesEdits:
|
||||||
|
err, _ = ImagesEditsHandler(c, resp)
|
||||||
default:
|
default:
|
||||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||||
}
|
}
|
||||||
|
@ -7,11 +7,10 @@ var ModelList = []string{
|
|||||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||||
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||||
"gpt-4o", "gpt-4o-2024-05-13",
|
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest",
|
||||||
"gpt-4o-2024-08-06",
|
|
||||||
"gpt-4o-2024-11-20",
|
|
||||||
"chatgpt-4o-latest",
|
|
||||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||||
|
"gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17",
|
||||||
|
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
|
||||||
"gpt-4-vision-preview",
|
"gpt-4-vision-preview",
|
||||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||||
|
@ -3,12 +3,30 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ImagesEditsHandler just copy response body to client
|
||||||
|
//
|
||||||
|
// https://platform.openai.com/docs/api-reference/images/createEdit
|
||||||
|
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
|
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
var imageResponse ImageResponse
|
var imageResponse ImageResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/common/conv"
|
"github.com/songquanpeng/one-api/common/conv"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/common/render"
|
"github.com/songquanpeng/one-api/common/render"
|
||||||
|
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
)
|
)
|
||||||
@ -23,144 +25,300 @@ const (
|
|||||||
dataPrefixLength = len(dataPrefix)
|
dataPrefixLength = len(dataPrefix)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StreamHandler processes streaming responses from OpenAI API
|
||||||
|
// It handles incremental content delivery and accumulates the final response text
|
||||||
|
// Returns error (if any), accumulated response text, and token usage information
|
||||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
|
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
|
||||||
|
// Initialize accumulators for the response
|
||||||
responseText := ""
|
responseText := ""
|
||||||
reasoningText := ""
|
reasoningText := ""
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(bufio.ScanLines)
|
|
||||||
var usage *model.Usage
|
var usage *model.Usage
|
||||||
|
|
||||||
|
// Set up scanner for reading the stream line by line
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
buffer := make([]byte, 256*1024) // 256KB buffer for large messages
|
||||||
|
scanner.Buffer(buffer, len(buffer))
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
|
// Set response headers for SSE
|
||||||
common.SetEventStreamHeaders(c)
|
common.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
doneRendered := false
|
doneRendered := false
|
||||||
|
|
||||||
|
// Process each line from the stream
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := NormalizeDataLine(scanner.Text())
|
||||||
if len(data) < dataPrefixLength { // ignore blank line or wrong format
|
|
||||||
continue
|
// logger.Debugf(c.Request.Context(), "stream response: %s", data)
|
||||||
|
|
||||||
|
// Skip lines that don't match expected format
|
||||||
|
if len(data) < dataPrefixLength {
|
||||||
|
continue // Ignore blank line or wrong format
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify line starts with expected prefix
|
||||||
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
|
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for stream termination
|
||||||
if strings.HasPrefix(data[dataPrefixLength:], done) {
|
if strings.HasPrefix(data[dataPrefixLength:], done) {
|
||||||
render.StringData(c, data)
|
render.StringData(c, data)
|
||||||
doneRendered = true
|
doneRendered = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process based on relay mode
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relaymode.ChatCompletions:
|
case relaymode.ChatCompletions:
|
||||||
var streamResponse ChatCompletionsStreamResponse
|
var streamResponse ChatCompletionsStreamResponse
|
||||||
|
|
||||||
|
// Parse the JSON response
|
||||||
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
logger.Errorf(c.Request.Context(), "unmarshalling stream data %q got %+v", data, err)
|
||||||
render.StringData(c, data) // if error happened, pass the data to client
|
render.StringData(c, data) // Pass raw data to client if parsing fails
|
||||||
continue // just ignore the error
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip empty choices (Azure specific behavior)
|
||||||
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
||||||
// but for empty choice and no usage, we should not pass it to client, this is for azure
|
continue
|
||||||
continue // just ignore empty choice
|
|
||||||
}
|
}
|
||||||
render.StringData(c, data)
|
|
||||||
|
// Process each choice in the response
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
if choice.Delta.Reasoning != nil {
|
// Extract reasoning content from different possible fields
|
||||||
reasoningText += *choice.Delta.Reasoning
|
currentReasoningChunk := extractReasoningContent(&choice.Delta)
|
||||||
}
|
|
||||||
if choice.Delta.ReasoningContent != nil {
|
// Update accumulated reasoning text
|
||||||
reasoningText += *choice.Delta.ReasoningContent
|
if currentReasoningChunk != "" {
|
||||||
|
reasoningText += currentReasoningChunk
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the reasoning content in the format requested by client
|
||||||
|
choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
|
||||||
|
|
||||||
|
// Accumulate response content
|
||||||
responseText += conv.AsString(choice.Delta.Content)
|
responseText += conv.AsString(choice.Delta.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send the processed data to the client
|
||||||
|
render.StringData(c, data)
|
||||||
|
|
||||||
|
// Update usage information if available
|
||||||
if streamResponse.Usage != nil {
|
if streamResponse.Usage != nil {
|
||||||
usage = streamResponse.Usage
|
usage = streamResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
case relaymode.Completions:
|
case relaymode.Completions:
|
||||||
|
// Send the data immediately for Completions mode
|
||||||
render.StringData(c, data)
|
render.StringData(c, data)
|
||||||
|
|
||||||
var streamResponse CompletionsStreamResponse
|
var streamResponse CompletionsStreamResponse
|
||||||
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accumulate text from all choices
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseText += choice.Text
|
responseText += choice.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for scanner errors
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
logger.SysError("error reading stream: " + err.Error())
|
logger.SysError("error reading stream: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure stream termination is sent to client
|
||||||
if !doneRendered {
|
if !doneRendered {
|
||||||
render.Done(c)
|
render.Done(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := resp.Body.Close()
|
// Clean up resources
|
||||||
if err != nil {
|
if err := resp.Body.Close(); err != nil {
|
||||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
|
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the complete response text (reasoning + content) and usage
|
||||||
return nil, reasoningText + responseText, usage
|
return nil, reasoningText + responseText, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to extract reasoning content from message delta
|
||||||
|
func extractReasoningContent(delta *model.Message) string {
|
||||||
|
content := ""
|
||||||
|
|
||||||
|
// Extract reasoning from different possible fields
|
||||||
|
if delta.Reasoning != nil {
|
||||||
|
content += *delta.Reasoning
|
||||||
|
delta.Reasoning = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if delta.ReasoningContent != nil {
|
||||||
|
content += *delta.ReasoningContent
|
||||||
|
delta.ReasoningContent = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler processes non-streaming responses from OpenAI API
|
||||||
|
// Returns error (if any) and token usage information
|
||||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
var textResponse SlimTextResponse
|
// Read the entire response body
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
// Close the original response body
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
// Parse the response JSON
|
||||||
|
var textResponse SlimTextResponse
|
||||||
|
if err = json.Unmarshal(responseBody, &textResponse); err != nil {
|
||||||
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for API errors
|
||||||
if textResponse.Error.Type != "" {
|
if textResponse.Error.Type != "" {
|
||||||
return &model.ErrorWithStatusCode{
|
return &model.ErrorWithStatusCode{
|
||||||
Error: textResponse.Error,
|
Error: textResponse.Error,
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
// Process reasoning content in each choice
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
for _, msg := range textResponse.Choices {
|
||||||
// So the HTTPClient will be confused by the response.
|
reasoningContent := processReasoningContent(&msg)
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
// Set reasoning in requested format if content exists
|
||||||
c.Writer.Header().Set(k, v[0])
|
if reasoningContent != "" {
|
||||||
|
msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset response body for forwarding to client
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))
|
||||||
|
|
||||||
|
// Forward all response headers (not just first value of each)
|
||||||
|
for k, values := range resp.Header {
|
||||||
|
for _, v := range values {
|
||||||
|
c.Writer.Header().Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set response status and copy body to client
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
if _, err = io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
if err != nil {
|
|
||||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
// Close the reset body
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if textResponse.Usage.TotalTokens == 0 ||
|
// Calculate token usage if not provided by API
|
||||||
(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
|
calculateTokenUsage(&textResponse, promptTokens, modelName)
|
||||||
|
|
||||||
|
return nil, &textResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// processReasoningContent is a helper function to extract and process reasoning content from the message
|
||||||
|
func processReasoningContent(msg *TextResponseChoice) string {
|
||||||
|
var reasoningContent string
|
||||||
|
|
||||||
|
// Check different locations for reasoning content
|
||||||
|
switch {
|
||||||
|
case msg.Reasoning != nil:
|
||||||
|
reasoningContent = *msg.Reasoning
|
||||||
|
msg.Reasoning = nil
|
||||||
|
case msg.ReasoningContent != nil:
|
||||||
|
reasoningContent = *msg.ReasoningContent
|
||||||
|
msg.ReasoningContent = nil
|
||||||
|
case msg.Message.Reasoning != nil:
|
||||||
|
reasoningContent = *msg.Message.Reasoning
|
||||||
|
msg.Message.Reasoning = nil
|
||||||
|
case msg.Message.ReasoningContent != nil:
|
||||||
|
reasoningContent = *msg.Message.ReasoningContent
|
||||||
|
msg.Message.ReasoningContent = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return reasoningContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to calculate token usage
|
||||||
|
func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
|
||||||
|
// Calculate tokens if not provided by the API
|
||||||
|
if response.Usage.TotalTokens == 0 ||
|
||||||
|
(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
|
||||||
|
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range textResponse.Choices {
|
for _, choice := range response.Choices {
|
||||||
|
// Count content tokens
|
||||||
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
|
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
|
||||||
|
|
||||||
|
// Count reasoning tokens in all possible locations
|
||||||
if choice.Message.Reasoning != nil {
|
if choice.Message.Reasoning != nil {
|
||||||
completionTokens += CountToken(*choice.Message.Reasoning)
|
completionTokens += CountToken(*choice.Message.Reasoning)
|
||||||
}
|
}
|
||||||
|
if choice.Message.ReasoningContent != nil {
|
||||||
|
completionTokens += CountToken(*choice.Message.ReasoningContent)
|
||||||
|
}
|
||||||
|
if choice.Reasoning != nil {
|
||||||
|
completionTokens += CountToken(*choice.Reasoning)
|
||||||
|
}
|
||||||
if choice.ReasoningContent != nil {
|
if choice.ReasoningContent != nil {
|
||||||
completionTokens += CountToken(*choice.ReasoningContent)
|
completionTokens += CountToken(*choice.ReasoningContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
textResponse.Usage = model.Usage{
|
|
||||||
|
// Set usage values
|
||||||
|
response.Usage = model.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TotalTokens: promptTokens + completionTokens,
|
TotalTokens: promptTokens + completionTokens,
|
||||||
}
|
}
|
||||||
|
} else if hasAudioTokens(response) {
|
||||||
|
// Handle audio tokens conversion
|
||||||
|
calculateAudioTokens(response, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to check if response has audio tokens
|
||||||
|
func hasAudioTokens(response *SlimTextResponse) bool {
|
||||||
|
return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
|
||||||
|
(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to calculate audio token usage
|
||||||
|
func calculateAudioTokens(response *SlimTextResponse, modelName string) {
|
||||||
|
// Convert audio tokens for prompt
|
||||||
|
if response.PromptTokensDetails != nil {
|
||||||
|
response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
|
||||||
|
int(math.Ceil(
|
||||||
|
float64(response.PromptTokensDetails.AudioTokens)*
|
||||||
|
ratio.GetAudioPromptRatio(modelName),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &textResponse.Usage
|
// Convert audio tokens for completion
|
||||||
|
if response.CompletionTokensDetails != nil {
|
||||||
|
response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
|
||||||
|
int(math.Ceil(
|
||||||
|
float64(response.CompletionTokensDetails.AudioTokens)*
|
||||||
|
ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate total tokens
|
||||||
|
response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/relay/model"
|
import (
|
||||||
|
"mime/multipart"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
type TextContent struct {
|
type TextContent struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
@ -71,6 +75,24 @@ type TextToSpeechRequest struct {
|
|||||||
ResponseFormat string `json:"response_format"`
|
ResponseFormat string `json:"response_format"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AudioTranscriptionRequest struct {
|
||||||
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||||
|
Model string `form:"model" binding:"required"`
|
||||||
|
Language string `form:"language"`
|
||||||
|
Prompt string `form:"prompt"`
|
||||||
|
ReponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
|
||||||
|
Temperature float64 `form:"temperature"`
|
||||||
|
TimestampGranularity []string `form:"timestamp_granularity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioTranslationRequest struct {
|
||||||
|
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||||
|
Model string `form:"model" binding:"required"`
|
||||||
|
Prompt string `form:"prompt"`
|
||||||
|
ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
|
||||||
|
Temperature float64 `form:"temperature"`
|
||||||
|
}
|
||||||
|
|
||||||
type UsageOrResponseText struct {
|
type UsageOrResponseText struct {
|
||||||
*model.Usage
|
*model.Usage
|
||||||
ResponseText string
|
ResponseText string
|
||||||
@ -110,12 +132,14 @@ type EmbeddingResponse struct {
|
|||||||
model.Usage `json:"usage"`
|
model.Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageData represents an image in the response
|
||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Url string `json:"url,omitempty"`
|
Url string `json:"url,omitempty"`
|
||||||
B64Json string `json:"b64_json,omitempty"`
|
B64Json string `json:"b64_json,omitempty"`
|
||||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageResponse represents the response structure for image generations
|
||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
Data []ImageData `json:"data"`
|
Data []ImageData `json:"data"`
|
||||||
|
@ -1,16 +1,20 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/image"
|
"github.com/songquanpeng/one-api/common/image"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|||||||
return len(tokenEncoder.Encode(text, nil, nil))
|
return len(tokenEncoder.Encode(text, nil, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenMessages(messages []model.Message, model string) int {
|
// CountTokenMessages counts the number of tokens in a list of messages.
|
||||||
tokenEncoder := getTokenEncoder(model)
|
func CountTokenMessages(ctx context.Context,
|
||||||
|
messages []model.Message, actualModel string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(actualModel)
|
||||||
// Reference:
|
// Reference:
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
|
|||||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
var tokensPerMessage int
|
var tokensPerMessage int
|
||||||
var tokensPerName int
|
var tokensPerName int
|
||||||
if model == "gpt-3.5-turbo-0301" {
|
if actualModel == "gpt-3.5-turbo-0301" {
|
||||||
tokensPerMessage = 4
|
tokensPerMessage = 4
|
||||||
tokensPerName = -1 // If there's a name, the role is omitted
|
tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
} else {
|
} else {
|
||||||
tokensPerMessage = 3
|
tokensPerMessage = 3
|
||||||
tokensPerName = 1
|
tokensPerName = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenNum := 0
|
tokenNum := 0
|
||||||
|
var totalAudioTokens float64
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tokenNum += tokensPerMessage
|
tokenNum += tokensPerMessage
|
||||||
switch v := message.Content.(type) {
|
contents := message.ParseContent()
|
||||||
case string:
|
for _, content := range contents {
|
||||||
tokenNum += getTokenNum(tokenEncoder, v)
|
switch content.Type {
|
||||||
case []any:
|
case model.ContentTypeText:
|
||||||
for _, it := range v {
|
if content.Text != nil {
|
||||||
m := it.(map[string]any)
|
tokenNum += getTokenNum(tokenEncoder, *content.Text)
|
||||||
switch m["type"] {
|
}
|
||||||
case "text":
|
case model.ContentTypeImageURL:
|
||||||
if textValue, ok := m["text"]; ok {
|
imageTokens, err := countImageTokens(
|
||||||
if textString, ok := textValue.(string); ok {
|
content.ImageURL.Url,
|
||||||
tokenNum += getTokenNum(tokenEncoder, textString)
|
content.ImageURL.Detail,
|
||||||
}
|
actualModel)
|
||||||
}
|
if err != nil {
|
||||||
case "image_url":
|
logger.SysError("error counting image tokens: " + err.Error())
|
||||||
imageUrl, ok := m["image_url"].(map[string]any)
|
} else {
|
||||||
if ok {
|
tokenNum += imageTokens
|
||||||
url := imageUrl["url"].(string)
|
}
|
||||||
detail := ""
|
case model.ContentTypeInputAudio:
|
||||||
if imageUrl["detail"] != nil {
|
audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
|
||||||
detail = imageUrl["detail"].(string)
|
if err != nil {
|
||||||
}
|
logger.SysError("error decoding audio data: " + err.Error())
|
||||||
imageTokens, err := countImageTokens(url, detail, model)
|
}
|
||||||
if err != nil {
|
|
||||||
logger.SysError("error counting image tokens: " + err.Error())
|
audioTokens, err := helper.GetAudioTokens(ctx,
|
||||||
} else {
|
bytes.NewReader(audioData),
|
||||||
tokenNum += imageTokens
|
ratio.GetAudioPromptTokensPerSecond(actualModel))
|
||||||
}
|
if err != nil {
|
||||||
}
|
logger.SysError("error counting audio tokens: " + err.Error())
|
||||||
|
} else {
|
||||||
|
totalAudioTokens += audioTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenNum += int(math.Ceil(totalAudioTokens))
|
||||||
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if message.Name != nil {
|
if message.Name != nil {
|
||||||
tokenNum += tokensPerName
|
tokenNum += tokensPerName
|
||||||
|
@ -3,6 +3,7 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
@ -21,3 +22,11 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
|
|||||||
StatusCode: statusCode,
|
StatusCode: statusCode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NormalizeDataLine(data string) string {
|
||||||
|
if strings.HasPrefix(data, "data:") {
|
||||||
|
content := strings.TrimLeft(data[len("data:"):], " ")
|
||||||
|
return "data: " + content
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.Errorf("not implement")
|
return nil, errors.Errorf("not implement")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return DrawImageRequest{
|
return DrawImageRequest{
|
||||||
Input: ImageInput{
|
Input: ImageInput{
|
||||||
Steps: 25,
|
Steps: 25,
|
||||||
|
@ -69,7 +69,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return convertedRequest, nil
|
return convertedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -43,10 +43,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
|||||||
return textRequest, nil
|
return textRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relaymode.ChatCompletions:
|
case relaymode.ChatCompletions:
|
||||||
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model)
|
||||||
case relaymode.Completions:
|
case relaymode.Completions:
|
||||||
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||||
case relaymode.Moderations:
|
case relaymode.Moderations:
|
||||||
|
@ -157,7 +157,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
channeltype.Ali,
|
channeltype.Ali,
|
||||||
channeltype.Replicate,
|
channeltype.Replicate,
|
||||||
channeltype.Baidu:
|
channeltype.Baidu:
|
||||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
// pre-consume quota
|
// pre-consume quota
|
||||||
promptTokens := getPromptTokens(textRequest, meta.Mode)
|
promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode)
|
||||||
meta.PromptTokens = promptTokens
|
meta.PromptTokens = promptTokens
|
||||||
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
|
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
|
||||||
if bizErr != nil {
|
if bizErr != nil {
|
||||||
|
@ -13,4 +13,5 @@ const (
|
|||||||
AudioTranslation
|
AudioTranslation
|
||||||
// Proxy is a special relay mode for proxying requests to custom upstream
|
// Proxy is a special relay mode for proxying requests to custom upstream
|
||||||
Proxy
|
Proxy
|
||||||
|
ImagesEdits
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user