mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
Merge pull request #367 from Calcium-Ion/audio
feat: support cloudflare tts
This commit is contained in:
commit
0f94ff47b5
73
common/str.go
Normal file
73
common/str.go
Normal file
@ -0,0 +1,73 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func GetRandomString(length int) string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
}
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func MapToJsonStr(m map[string]interface{}) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func MapToJsonStrFloat(m map[string]float64) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func StringsContains(strs []string, str string) bool {
|
||||
for _, s := range strs {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StringToByteSlice []byte only read, panic on append
|
||||
func StringToByteSlice(s string) []byte {
|
||||
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"html/template"
|
||||
@ -13,7 +12,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
@ -159,15 +157,6 @@ func GenerateKey() string {
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func GetRandomString(length int) string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
}
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func GetRandomInt(max int) int {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return rand.Intn(max)
|
||||
@ -194,56 +183,7 @@ func MessageWithRequestId(message string, id string) string {
|
||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func StringsContains(strs []string, str string) bool {
|
||||
for _, s := range strs {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StringToByteSlice []byte only read, panic on append
|
||||
func StringToByteSlice(s string) []byte {
|
||||
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func MapToJsonStr(m map[string]interface{}) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func MapToJsonStrFloat(m map[string]float64) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
adaptor.Init(meta)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
|
||||
if err != nil {
|
||||
@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
return err, nil
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
err := service.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
|
@ -131,7 +131,7 @@ func init() {
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
adaptor.Init(meta, dto.GeneralOpenAIRequest{})
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
}
|
||||
|
33
dto/audio.go
33
dto/audio.go
@ -1,13 +1,34 @@
|
||||
package dto
|
||||
|
||||
type TextToSpeechRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Input string `json:"input" binding:"required"`
|
||||
Voice string `json:"voice" binding:"required"`
|
||||
Speed float64 `json:"speed"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
type AudioRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice string `json:"voice"`
|
||||
Speed float64 `json:"speed,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type WhisperVerboseJSONResponse struct {
|
||||
Task string `json:"task,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Segments []Segment `json:"segments,omitempty"`
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
Id int `json:"id"`
|
||||
Seek int `json:"seek"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
AvgLogprob float64 `json:"avg_logprob"`
|
||||
CompressionRatio float64 `json:"compression_ratio"`
|
||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
||||
}
|
||||
|
@ -159,18 +159,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
relayMode := relayconstant.RelayModeAudioSpeech
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
|
||||
relayMode = relayconstant.RelayModeAudioTranslation
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
|
||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
|
||||
relayMode = relayconstant.RelayModeAudioTranscription
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
@ -10,12 +10,13 @@ import (
|
||||
|
||||
type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
||||
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
|
@ -15,11 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -7,14 +7,19 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
} else {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
|
@ -20,12 +20,17 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
|
@ -16,12 +16,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
}
|
||||
|
||||
|
@ -21,12 +21,17 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -15,10 +16,7 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
@ -58,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
// 添加文件字段
|
||||
file, _, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
return nil, errors.New("file is required")
|
||||
}
|
||||
defer file.Close()
|
||||
// 打开临时文件用于保存上传的文件内容
|
||||
requestBody := &bytes.Buffer{}
|
||||
|
||||
// 将上传的文件内容复制到临时文件
|
||||
if _, err := io.Copy(requestBody, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = cfSTTHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -11,3 +11,11 @@ type CfRequest struct {
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type CfAudioResponse struct {
|
||||
Result CfSTTResult `json:"result"`
|
||||
}
|
||||
|
||||
type CfSTTResult struct {
|
||||
Text string `json:"text"`
|
||||
}
|
@ -119,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var cfResp CfAudioResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
audioResp := &dto.AudioResponse{
|
||||
Text: cfResp.Result.Text,
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(audioResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
@ -14,10 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -14,12 +14,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -14,10 +14,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
// 定义一个映射,存储模型名称和对应的版本
|
||||
|
@ -15,10 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -15,10 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -1,10 +1,13 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@ -14,21 +17,16 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
ChannelType int
|
||||
ChannelType int
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
@ -83,15 +81,73 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
a.ResponseFormat = request.ResponseFormat
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
return bytes.NewReader(jsonData), nil
|
||||
} else {
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
|
||||
// 添加文件字段
|
||||
file, header, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
return nil, errors.New("file is required")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
part, err := writer.CreateFormFile("file", header.Filename)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed")
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, errors.New("copy file failed")
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return &requestBody, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = OpenaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = OpenaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -167,10 +168,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@ -208,11 +206,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
@ -227,3 +221,134 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
}
|
||||
return nil, &simpleResponse.Usage
|
||||
}
|
||||
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var audioResp dto.AudioResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &audioResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
var text string
|
||||
switch responseFormat {
|
||||
case "json":
|
||||
text, err = getTextFromJSON(responseBody)
|
||||
case "text":
|
||||
text, err = getTextFromText(responseBody)
|
||||
case "srt":
|
||||
text, err = getTextFromSRT(responseBody)
|
||||
case "verbose_json":
|
||||
text, err = getTextFromVerboseJSON(responseBody)
|
||||
case "vtt":
|
||||
text, err = getTextFromVTT(responseBody)
|
||||
}
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func getTextFromVTT(body []byte) (string, error) {
|
||||
return getTextFromSRT(body)
|
||||
}
|
||||
|
||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
||||
var whisperResponse dto.WhisperVerboseJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func getTextFromSRT(body []byte) (string, error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
||||
var builder strings.Builder
|
||||
var textLine bool
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if textLine {
|
||||
builder.WriteString(line)
|
||||
textLine = false
|
||||
continue
|
||||
} else if strings.Contains(line, "-->") {
|
||||
textLine = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func getTextFromText(body []byte) (string, error) {
|
||||
return strings.TrimSuffix(string(body), "\n"), nil
|
||||
}
|
||||
|
||||
func getTextFromJSON(body []byte) (string, error) {
|
||||
var whisperResponse dto.AudioResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
@ -15,12 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -15,12 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -23,12 +23,17 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.Action = "ChatCompletions"
|
||||
a.Version = "2023-09-01"
|
||||
a.Timestamp = common.GetTimestamp()
|
||||
|
@ -16,12 +16,17 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -14,12 +14,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -15,12 +15,17 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
@ -1,50 +1,17 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var StopFinishReason = "stop"
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var textResponse dto.TextResponseWithError
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody)
|
||||
return
|
||||
}
|
||||
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
||||
return
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
|
@ -13,6 +13,7 @@ const (
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
@ -22,16 +23,19 @@ const (
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskImageSeed
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
RelayModeMidjourneyAction
|
||||
RelayModeMidjourneyModal
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeSwapFace
|
||||
|
||||
RelayModeAudioSpeech // tts
|
||||
RelayModeAudioTranscription // whisper
|
||||
RelayModeAudioTranslation // whisper
|
||||
|
||||
RelayModeSunoFetch
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeRerank
|
||||
)
|
||||
|
||||
|
@ -1,13 +1,10 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@ -16,69 +13,71 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
var audioRequest dto.TextToSpeechRequest
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err := common.UnmarshalBodyReusable(c, &audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
} else {
|
||||
audioRequest = dto.TextToSpeechRequest{
|
||||
Model: "whisper-1",
|
||||
}
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
audioRequest := &dto.AudioRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, audioRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err := common.UnmarshalBodyReusable(c, &audioRequest)
|
||||
|
||||
// request validation
|
||||
if audioRequest.Model == "" {
|
||||
return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
if audioRequest.Voice == "" {
|
||||
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
}
|
||||
var err error
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
err = service.CheckSensitiveInput(audioRequest.Input)
|
||||
err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if audioRequest.Model == "" {
|
||||
audioRequest.Model = c.PostForm("model")
|
||||
}
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if audioRequest.ResponseFormat == "" {
|
||||
audioRequest.ResponseFormat = "json"
|
||||
}
|
||||
}
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
succeed := false
|
||||
defer func() {
|
||||
if succeed {
|
||||
return
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
// we need to roll back the pre-consumed quota
|
||||
defer func() {
|
||||
go func() {
|
||||
// negative means add quota back for token & user
|
||||
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
@ -122,133 +105,44 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
audioRequest.Model = modelMap[audioRequest.Model]
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = audioRequest.Model
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiVersion := relaycommon.GetAPIVersion(c)
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
|
||||
}
|
||||
|
||||
requestBody := c.Request.Body
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
req.Header.Set("api-key", apiKey)
|
||||
req.ContentLength = c.Request.ContentLength
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return relaycommon.RelayErrorHandler(resp)
|
||||
}
|
||||
succeed = true
|
||||
|
||||
var audioResponse dto.AudioResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
quota := 0
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
quota = promptTokens
|
||||
} else {
|
||||
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
|
||||
}
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
other := make(map[string]interface{})
|
||||
other["model_ratio"] = modelRatio
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||
|
||||
} else {
|
||||
err = json.Unmarshal(responseBody, &audioResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
contains, words := service.SensitiveWordContains(audioResponse.Text)
|
||||
if contains {
|
||||
return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
if resp != nil {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -180,7 +180,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return relaycommon.RelayErrorHandler(resp)
|
||||
return service.RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
var textResponse dto.ImageResponse
|
||||
|
@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(textRequest.Model, false)
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
@ -150,7 +150,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo, *textRequest)
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
|
||||
@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -300,7 +300,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
if !usePrice {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
|
@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.InitRerank(relayInfo, *rerankRequest)
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
|
@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
|
||||
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: dto.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
Loading…
Reference in New Issue
Block a user