From bcc7f3edb28f390c19e368281643a472f3c802b2 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 22:07:10 +0800 Subject: [PATCH 1/3] refactor: audio relay --- common/str.go | 73 ++++++++ common/utils.go | 60 ------- controller/channel-test.go | 2 +- controller/model.go | 2 +- dto/audio.go | 33 +++- middleware/distributor.go | 20 ++- relay/channel/adapter.go | 5 +- relay/channel/ali/adaptor.go | 10 +- relay/channel/api_request.go | 36 +++- relay/channel/aws/adaptor.go | 11 +- relay/channel/baidu/adaptor.go | 11 +- relay/channel/claude/adaptor.go | 11 +- relay/channel/cloudflare/adaptor.go | 11 +- relay/channel/cohere/adaptor.go | 12 +- relay/channel/dify/adaptor.go | 11 +- relay/channel/gemini/adaptor.go | 11 +- relay/channel/jina/adaptor.go | 11 +- relay/channel/ollama/adaptor.go | 11 +- relay/channel/openai/adaptor.go | 84 ++++++++-- relay/channel/openai/relay-openai.go | 137 +++++++++++++++ relay/channel/palm/adaptor.go | 11 +- relay/channel/perplexity/adaptor.go | 11 +- relay/channel/tencent/adaptor.go | 11 +- relay/channel/xunfei/adaptor.go | 11 +- relay/channel/zhipu/adaptor.go | 11 +- relay/channel/zhipu_4v/adaptor.go | 11 +- relay/constant/relay_mode.go | 10 +- relay/relay-audio.go | 239 ++++++++------------------- relay/relay-text.go | 8 +- relay/relay_rerank.go | 2 +- 30 files changed, 567 insertions(+), 320 deletions(-) create mode 100644 common/str.go diff --git a/common/str.go b/common/str.go new file mode 100644 index 0000000..d61adb1 --- /dev/null +++ b/common/str.go @@ -0,0 +1,73 @@ +package common + +import ( + "encoding/json" + "math/rand" + "strconv" + "unsafe" +) + +func GetStringIfEmpty(str string, defaultValue string) string { + if str == "" { + return defaultValue + } + return str +} + +func GetRandomString(length int) string { + //rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func MapToJsonStr(m map[string]interface{}) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func MapToJsonStrFloat(m map[string]float64) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func StrToMap(str string) map[string]interface{} { + m := make(map[string]interface{}) + err := json.Unmarshal([]byte(str), &m) + if err != nil { + return nil + } + return m +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} + +func StringsContains(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} + +// StringToByteSlice []byte only read, panic on append +func StringToByteSlice(s string) []byte { + tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) + tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} + return *(*[]byte)(unsafe.Pointer(&tmp2)) +} diff --git a/common/utils.go b/common/utils.go index 3e047c4..3d95508 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,7 +1,6 @@ package common import ( - "encoding/json" "fmt" "github.com/google/uuid" "html/template" @@ -13,7 +12,6 @@ import ( "strconv" "strings" "time" - "unsafe" ) func OpenBrowser(url string) { @@ -159,15 +157,6 @@ func GenerateKey() string { return string(key) } -func GetRandomString(length int) string { - //rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - func GetRandomInt(max int) int { //rand.Seed(time.Now().UnixNano()) return rand.Intn(max) @@ -194,56 +183,7 @@ func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} - -func StringsContains(strs []string, str string) bool { - for _, s := range strs { - if s == str { - return true - } - } - return false -} - -// StringToByteSlice []byte only read, panic on append -func StringToByteSlice(s string) []byte { - tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) - tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} - return *(*[]byte)(unsafe.Pointer(&tmp2)) -} - func RandomSleep() { // Sleep for 0-3000 ms time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) } - -func MapToJsonStr(m map[string]interface{}) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - -func MapToJsonStrFloat(m map[string]float64) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - -func StrToMap(str string) map[string]interface{} { - m := make(map[string]interface{}) - err := json.Unmarshal([]byte(str), &m) - if err != nil { - return nil - } - return m -} diff --git a/controller/channel-test.go b/controller/channel-test.go index 4ad7457..e1af673 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -85,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr meta.UpstreamModelName = testModel common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) - adaptor.Init(meta, *request) + adaptor.Init(meta) convertedRequest, err := adaptor.ConvertRequest(c, meta, request) if err != nil { diff --git a/controller/model.go b/controller/model.go index 7e3a321..6b4a878 100644 --- a/controller/model.go +++ b/controller/model.go @@ -131,7 +131,7 @@ func init() { } meta := &relaycommon.RelayInfo{ChannelType: i} adaptor := relay.GetAdaptor(apiType) - adaptor.Init(meta, dto.GeneralOpenAIRequest{}) + adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } } diff --git a/dto/audio.go b/dto/audio.go index c67d678..c36b3da 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,13 +1,34 @@ package dto -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` +type AudioRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + Speed float64 `json:"speed,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` } type AudioResponse struct { Text string `json:"text"` } + +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 9f75207..2552f29 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -154,18 +154,20 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } + relayMode := relayconstant.RelayModeAudioSpeech + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranslation + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranscription } + c.Set("relay_mode", relayMode) } return &modelRequest, shouldSelectChannel, nil } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 7064b88..870b2b0 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -10,12 +10,13 @@ import ( type Adaptor interface { // Init IsStream bool - Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) - InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) + Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) + ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) + ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index e03d29f..88990d1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -15,11 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ab1131f..423a91d 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,14 +7,19 @@ import ( "io" "net/http" "one-api/relay/common" + "one-api/relay/constant" "one-api/service" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if info.IsStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + // multipart/form-data + } else { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if info.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } } } @@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return resp, nil } +func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + // set form data + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + + err = a.SetupRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { resp, err := service.GetHttpClient().Do(req) if err != nil { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 8214777..44a870d 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,12 +20,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 40a0696..cc0be56 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,12 +16,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 8e4c75d..0544695 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -21,12 +21,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 53b5a91..2f3c46d 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 84243aa..3945774 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -1,6 +1,7 @@ package cohere import ( + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -14,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 8dbe8b8..b582da2 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index f223fbf..de7761a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -14,10 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } // 定义一个映射,存储模型名称和对应的版本 diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index d0a379a..6a04d08 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 15ced27..540ec85 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6dc56d0..820c2bc 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -1,10 +1,13 @@ package openai import ( + "bytes" + "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/dto" @@ -14,21 +17,16 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" + "one-api/relay/constant" "strings" ) type Adaptor struct { - ChannelType int + ChannelType int + ResponseFormat string } -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil -} - -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } @@ -83,15 +81,73 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return request, nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + a.ResponseFormat = request.ResponseFormat + if info.RelayMode == constant.RelayModeAudioSpeech { + jsonData, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("error marshalling object: %w", err) + } + return bytes.NewReader(jsonData), nil + } else { + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + writer.WriteField("model", request.Model) + + // 添加文件字段 + file, header, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + + part, err := writer.CreateFormFile("file", header.Filename) + if err != nil { + return nil, errors.New("create form file failed") + } + if _, err := io.Copy(part, file); err != nil { + return nil, errors.New("copy file failed") + } + + // 关闭 multipart 编写器以设置分界线 + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &requestBody, nil + } +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return channel.DoApiRequest(a, c, info, requestBody) + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + return channel.DoFormRequest(a, c, info, requestBody) + } else { + return channel.DoApiRequest(a, c, info, requestBody) + } } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = OpenaiStreamHandler(c, resp, info) - } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + switch info.RelayMode { + case constant.RelayModeAudioSpeech: + err, usage = OpenaiTTSHandler(c, resp, info) + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) + default: + if info.IsStream { + err, usage = OpenaiStreamHandler(c, resp, info) + } else { + err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b71fcce..4b27a07 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -224,3 +225,139 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } return nil, &simpleResponse.Usage } + +func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.TotalTokens = info.PromptTokens + + return nil, usage +} + +func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var audioResp dto.AudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse dto.WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse dto.AudioResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 51d1399..d8c4ffb 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 40aa0f4..d3ed222 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 3dd9115..5811c87 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -23,12 +23,17 @@ type Adaptor struct { Timestamp int64 } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.Action = "ChatCompletions" a.Version = "2023-09-01" a.Timestamp = common.GetTimestamp() diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index adb054e..f499bec 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,12 +16,17 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 09345ca..f98581f 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index bdce639..b34b756 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index ed15b08..a072c74 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -13,6 +13,7 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeMidjourneyImagine RelayModeMidjourneyDescribe RelayModeMidjourneyBlend @@ -22,16 +23,19 @@ const ( RelayModeMidjourneyTaskFetch RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation RelayModeMidjourneyAction RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + + RelayModeAudioSpeech // tts + RelayModeAudioTranscription // whisper + RelayModeAudioTranslation // whisper + RelayModeSunoFetch RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeRerank ) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 9137721..05b723c 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -1,13 +1,10 @@ package relay import ( - "bytes" - "context" "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "io" "net/http" "one-api/common" "one-api/constant" @@ -16,69 +13,71 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" - "strings" - "time" ) -func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var audioRequest dto.TextToSpeechRequest - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err := common.UnmarshalBodyReusable(c, &audioRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - } else { - audioRequest = dto.TextToSpeechRequest{ - Model: "whisper-1", - } +func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err } - //err := common.UnmarshalBodyReusable(c, &audioRequest) - - // request validation - if audioRequest.Model == "" { - return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - - if strings.HasPrefix(audioRequest.Model, "tts-1") { - if audioRequest.Voice == "" { - return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) + switch info.RelayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") } - } - var err error - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if strings.HasPrefix(audioRequest.Model, "tts-1") { if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(audioRequest.Input) + err := service.CheckSensitiveInput(audioRequest.Input) if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + return nil, err } } + default: + if audioRequest.Model == "" { + audioRequest.Model = c.PostForm("model") + } + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + audioRequest, err := getAndValidAudioRequest(c, relayInfo) + + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest) + } + + promptTokens := 0 + preConsumedTokens := common.PreConsumedQuota + if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } preConsumedTokens = promptTokens + relayInfo.PromptTokens = promptTokens } + modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := common.GetGroupRatio(group) + groupRatio := common.GetGroupRatio(relayInfo.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } @@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { preConsumedQuota = 0 } if preConsumedQuota > 0 { - userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } - succeed := false - defer func() { - if succeed { - return - } - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func() { - go func() { - // negative means add quota back for token & user - returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota) - }() - }() - } - }() - // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" { @@ -123,132 +106,42 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(relayInfo) - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := relaycommon.GetAPIVersion(c) - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) - } - - requestBody := c.Request.Body - - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - req.Header.Set("api-key", apiKey) - req.ContentLength = c.Request.ContentLength - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := service.GetHttpClient().Do(req) + resp, err := adaptor.DoRequest(c, relayInfo, ioReader) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) - } - succeed = true - - var audioResponse dto.AudioResponse - - defer func(ctx context.Context) { - go func() { - useTimeSeconds := time.Now().Unix() - startTime.Unix() - quota := 0 - if strings.HasPrefix(audioRequest.Model, "tts-1") { - quota = promptTokens - } else { - quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model) - } - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - other := make(map[string]interface{}) - other["model_ratio"] = modelRatio - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - if strings.HasPrefix(audioRequest.Model, "tts-1") { - - } else { - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - contains, words := service.SensitiveWordContains(audioResponse.Text) - if contains { - return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest) + statusCodeMappingStr := c.GetString("status_code_mapping") + if resp != nil { + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) + postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index ef169fa..0438eba 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } } relayInfo.UpstreamModelName = textRequest.Model - modelPrice, success := common.GetModelPrice(textRequest.Model, false) + modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) groupRatio := common.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } - if !success { + if !getModelPriceSuccess { preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + int(textRequest.MaxTokens) @@ -150,7 +150,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.Init(relayInfo, *textRequest) + adaptor.Init(relayInfo) var requestBody io.Reader convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) @@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess) return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e32ca88..2fc4854 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.InitRerank(relayInfo, *rerankRequest) + adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { From ebb9b675b6f7d1db1e1eecb3345f20119ae75e1d Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 23:24:47 +0800 Subject: [PATCH 2/3] feat: support cloudflare audio --- controller/channel-test.go | 2 +- middleware/distributor.go | 2 + relay/channel/cloudflare/adaptor.go | 50 +++++++++++++------ relay/channel/cloudflare/{model.go => dto.go} | 8 +++ relay/channel/cloudflare/relay_cloudflare.go | 35 +++++++++++++ relay/channel/openai/relay-openai.go | 18 ++----- relay/common/relay_utils.go | 33 ------------ relay/relay-audio.go | 1 + relay/relay-image.go | 2 +- service/error.go | 7 ++- 10 files changed, 90 insertions(+), 68 deletions(-) rename relay/channel/cloudflare/{model.go => dto.go} (78%) diff --git a/controller/channel-test.go b/controller/channel-test.go index e1af673..90d02d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return err, nil } if resp != nil && resp.StatusCode != http.StatusOK { - err := relaycommon.RelayErrorHandler(resp) + err := service.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/middleware/distributor.go b/middleware/distributor.go index 2552f29..1ce787e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -161,9 +161,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranslation } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranscription } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 2f3c46d..a518da8 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -1,6 +1,7 @@ package cloudflare import ( + "bytes" "errors" "fmt" "github.com/gin-gonic/gin" @@ -15,16 +16,6 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } @@ -65,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + // 添加文件字段 + file, _, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + // 打开临时文件用于保存上传的文件内容 + requestBody := &bytes.Buffer{} + + // 将上传的文件内容复制到临时文件 + if _, err := io.Copy(requestBody, file); err != nil { + return nil, err + } + return requestBody, nil +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = cfStreamHandler(c, resp, info) - } else { - err, usage = cfHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fallthrough + case constant.RelayModeChatCompletions: + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = cfSTTHandler(c, resp, info) } return } diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/dto.go similarity index 78% rename from relay/channel/cloudflare/model.go rename to relay/channel/cloudflare/dto.go index c870813..2f6531c 100644 --- a/relay/channel/cloudflare/model.go +++ b/relay/channel/cloudflare/dto.go @@ -11,3 +11,11 @@ type CfRequest struct { Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` } + +type CfAudioResponse struct { + Result CfSTTResult `json:"result"` +} + +type CfSTTResult struct { + Text string `json:"text"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index d9319ef..69d6b85 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -119,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) _, _ = c.Writer.Write(jsonResponse) return nil, usage } + +func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var cfResp CfAudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &cfResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + audioResp := &dto.AudioResponse{ + Text: cfResp.Result.Text, + } + + jsonResponse, err := json.Marshal(audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4b27a07..651e82e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -165,10 +165,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. service.Done(c) - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } + resp.Body.Close() return nil, usage } @@ -206,11 +203,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - + resp.Body.Close() if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { @@ -257,7 +250,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens - return nil, usage } @@ -290,10 +282,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() var text string switch responseFormat { @@ -313,7 +302,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage.PromptTokens = info.PromptTokens usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return nil, usage } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 9ef9a8b..6daf003 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -1,50 +1,17 @@ package common import ( - "encoding/json" "fmt" "github.com/gin-gonic/gin" _ "image/gif" _ "image/jpeg" _ "image/png" - "io" - "net/http" "one-api/common" - "one-api/dto" - "strconv" "strings" ) var StopFinishReason = "stop" -func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { - OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: dto.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse dto.TextResponseWithError - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody) - return - } - OpenAIErrorWithStatusCode.Error = textResponse.Error - return -} - func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 05b723c..2a0278e 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -105,6 +105,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { audioRequest.Model = modelMap[audioRequest.Model] } } + relayInfo.UpstreamModelName = audioRequest.Model adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { diff --git a/relay/relay-image.go b/relay/relay-image.go index d83ec26..6d6e4d4 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -180,7 +180,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) + return service.RelayErrorHandler(resp) } var textResponse dto.ImageResponse diff --git a/service/error.go b/service/error.go index 0f6d472..3410de8 100644 --- a/service/error.go +++ b/service/error.go @@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, Error: dto.OpenAIError{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), }, } responseBody, err := io.ReadAll(resp.Body) From 86ca533f7ab44a719210e83db2822e388c7980ee Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 23:40:52 +0800 Subject: [PATCH 3/3] fix: fix bug --- relay/relay-text.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/relay-text.go b/relay/relay-text.go index 0438eba..9e1b9b7 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -300,7 +300,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN } totalTokens := promptTokens + completionTokens var logContent string - if modelPrice == -1 { + if !usePrice { logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)