From 6013219f5b93eee44fef8961d1096635cb0e27cf Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 29 Feb 2024 16:21:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/billing.go | 5 +- controller/channel-test.go | 4 +- controller/midjourney.go | 4 +- controller/model.go | 3 +- controller/relay.go | 2 +- dto/audio.go | 13 ++++ dto/dalle.go | 20 ++++++ dto/midjourney.go | 19 ++++++ dto/{request.go => text_request.go} | 0 dto/{response.go => text_response.go} | 26 ------- main.go | 4 +- relay/channel/adapter.go | 36 ---------- relay/channel/ali/adaptor.go | 6 +- relay/channel/api_request.go | 6 +- relay/channel/baidu/adaptor.go | 6 +- relay/channel/claude/adaptor.go | 6 +- relay/channel/gemini/adaptor.go | 6 +- relay/channel/openai/adaptor.go | 6 +- relay/channel/palm/adaptor.go | 6 +- relay/channel/tencent/adaptor.go | 6 +- relay/channel/xunfei/adaptor.go | 4 +- relay/channel/zhipu/adaptor.go | 6 +- relay/common/relay_info.go | 6 +- relay/common/relay_utils.go | 9 +++ relay/relay-audio.go | 25 +++---- relay/relay-image.go | 57 ++++++++-------- relay/relay-mj.go | 98 ++++++++++++++------------- relay/relay-text.go | 3 +- relay/relay_adaptor.go | 41 +++++++++++ service/token_counter.go | 2 +- 30 files changed, 240 insertions(+), 195 deletions(-) create mode 100644 dto/audio.go create mode 100644 dto/dalle.go create mode 100644 dto/midjourney.go rename dto/{request.go => text_request.go} (100%) rename dto/{response.go => text_response.go} (71%) create mode 100644 relay/relay_adaptor.go diff --git a/controller/billing.go b/controller/billing.go index f26e682..02fb8bd 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -3,6 +3,7 @@ package controller import ( "github.com/gin-gonic/gin" "one-api/common" + "one-api/dto" "one-api/model" ) @@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) { expiredTime = 0 } if err != nil { - openAIError := OpenAIError{ + openAIError := dto.OpenAIError{ Message: err.Error(), Type: "upstream_error", } @@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - openAIError := OpenAIError{ + openAIError := dto.OpenAIError{ Message: err.Error(), Type: "new_api_error", } diff --git a/controller/channel-test.go b/controller/channel-test.go index 1ea767a..2c24d2f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -12,7 +12,7 @@ import ( "one-api/common" "one-api/dto" "one-api/model" - relaychannel "one-api/relay/channel" + "one-api/relay" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/service" @@ -39,7 +39,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr c.Set("base_url", channel.GetBaseURL()) meta := relaycommon.GenRelayInfo(c) apiType := constant.ChannelType2APIType(channel.Type) - adaptor := relaychannel.GetAdaptor(apiType) + adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } diff --git a/controller/midjourney.go b/controller/midjourney.go index 0e28efc..1a42270 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -10,9 +10,9 @@ import ( "log" "net/http" "one-api/common" - "one-api/controller/relay" "one-api/model" relay2 "one-api/relay" + "one-api/service" "strconv" "strings" "time" @@ -223,7 +223,7 @@ func UpdateMidjourneyTaskBulk() { req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/json") req.Header.Set("mj-api-secret", midjourneyChannel.Key) - resp, err := relay.httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue diff --git a/controller/model.go b/controller/model.go index 1721cb7..05d725f 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" + "one-api/dto" ) // https://platform.openai.com/docs/api-reference/models/list @@ -639,7 +640,7 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - openAIError := OpenAIError{ + openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", diff --git a/controller/relay.go b/controller/relay.go index e22bb2d..a79e46c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -26,7 +26,7 @@ func Relay(c *gin.Context) { case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.RelayAudioHelper(c, relayMode) + err = relay.AudioHelper(c, relayMode) default: err = relay.TextHelper(c) } diff --git a/dto/audio.go b/dto/audio.go new file mode 100644 index 0000000..c67d678 --- /dev/null +++ b/dto/audio.go @@ -0,0 +1,13 @@ +package dto + +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} + +type AudioResponse struct { + Text string `json:"text"` +} diff --git a/dto/dalle.go b/dto/dalle.go new file mode 100644 index 0000000..d366051 --- /dev/null +++ b/dto/dalle.go @@ -0,0 +1,20 @@ +package dto + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} + +type ImageResponse struct { + Created int `json:"created"` + Data []struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + } +} diff --git a/dto/midjourney.go b/dto/midjourney.go new file mode 100644 index 0000000..4c67909 --- /dev/null +++ b/dto/midjourney.go @@ -0,0 +1,19 @@ +package dto + +type MidjourneyRequest struct { + Prompt string `json:"prompt"` + NotifyHook string `json:"notifyHook"` + Action string `json:"action"` + Index int `json:"index"` + State string `json:"state"` + TaskId string `json:"taskId"` + Base64Array []string `json:"base64Array"` + Content string `json:"content"` +} + +type MidjourneyResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Properties interface{} `json:"properties"` + Result string `json:"result"` +} diff --git a/dto/request.go b/dto/text_request.go similarity index 100% rename from dto/request.go rename to dto/text_request.go diff --git a/dto/response.go b/dto/text_response.go similarity index 71% rename from dto/response.go rename to dto/text_response.go index 620c083..752793a 100644 --- a/dto/response.go +++ b/dto/text_response.go @@ -33,14 +33,6 @@ type OpenAIEmbeddingResponse struct { Usage `json:"usage"` } -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - } -} - type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` @@ -66,21 +58,3 @@ type CompletionsStreamResponse struct { FinishReason string `json:"finish_reason"` } `json:"choices"` } - -type MidjourneyRequest struct { - Prompt string `json:"prompt"` - NotifyHook string `json:"notifyHook"` - Action string `json:"action"` - Index int `json:"index"` - State string `json:"state"` - TaskId string `json:"taskId"` - Base64Array []string `json:"base64Array"` - Content string `json:"content"` -} - -type MidjourneyResponse struct { - Code int `json:"code"` - Description string `json:"description"` - Properties interface{} `json:"properties"` - Result string `json:"result"` -} diff --git a/main.go b/main.go index 94496a4..234bca7 100644 --- a/main.go +++ b/main.go @@ -12,8 +12,8 @@ import ( "one-api/controller" "one-api/middleware" "one-api/model" - "one-api/relay/common" "one-api/router" + "one-api/service" "os" "strconv" @@ -106,7 +106,7 @@ func main() { common.SysLog("pprof enabled") } - common.InitTokenEncoders() + service.InitTokenEncoders() // Initialize HTTP server server := gin.New() diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ee696fd..d3886d5 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -5,17 +5,7 @@ import ( "io" "net/http" "one-api/dto" - "one-api/relay/channel/ali" - "one-api/relay/channel/baidu" - "one-api/relay/channel/claude" - "one-api/relay/channel/gemini" - "one-api/relay/channel/openai" - "one-api/relay/channel/palm" - "one-api/relay/channel/tencent" - "one-api/relay/channel/xunfei" - "one-api/relay/channel/zhipu" relaycommon "one-api/relay/common" - "one-api/relay/constant" ) type Adaptor interface { @@ -29,29 +19,3 @@ type Adaptor interface { GetModelList() []string GetChannelName() string } - -func GetAdaptor(apiType int) Adaptor { - switch apiType { - //case constant.APITypeAIProxyLibrary: - // return &aiproxy.Adaptor{} - case constant.APITypeAli: - return &ali.Adaptor{} - case constant.APITypeAnthropic: - return &claude.Adaptor{} - case constant.APITypeBaidu: - return &baidu.Adaptor{} - case constant.APITypeGemini: - return &gemini.Adaptor{} - case constant.APITypeOpenAI: - return &openai.Adaptor{} - case constant.APITypePaLM: - return &palm.Adaptor{} - case constant.APITypeTencent: - return &tencent.Adaptor{} - case constant.APITypeXunfei: - return &xunfei.Adaptor{} - case constant.APITypeZhipu: - return &zhipu.Adaptor{} - } - return nil -} diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index b79299a..bfe83db 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" ) @@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("Authorization", "Bearer "+info.ApiKey) if info.IsStream { req.Header.Set("X-DashScope-SSE", "enable") @@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index b0ef212..ef82645 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -6,11 +6,11 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" - relaycommon "one-api/relay/common" + "one-api/relay/common" "one-api/service" ) -func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) { +func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) if info.IsStream && c.Request.Header.Get("Accept") == "" { @@ -18,7 +18,7 @@ func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *htt } } -func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(info) if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index a07fccd..d2571dc 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -6,7 +6,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" ) @@ -46,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 130024b..a7245ee 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" ) @@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("x-api-key", info.ApiKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { @@ -42,7 +42,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 5a200eb..a613ae2 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" ) @@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("x-goog-api-key", info.ApiKey) return nil } @@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index bd01965..be8b621 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -8,7 +8,7 @@ import ( "net/http" "one-api/common" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" "strings" @@ -40,7 +40,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) if info.ChannelType == common.ChannelTypeAzure { req.Header.Set("api-key", info.ApiKey) return nil @@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 2a5f017..6458858 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" ) @@ -23,7 +23,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("x-goog-api-key", info.ApiKey) return nil } @@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 58b5e0d..7571659 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" "strings" @@ -25,7 +25,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) req.Header.Set("Authorization", a.Sign) req.Header.Set("X-TC-Action", info.UpstreamModelName) return nil @@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 9baebd7..79a4b12 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -6,7 +6,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" "strings" @@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) return nil } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 6fd3047..d437f1b 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -7,7 +7,7 @@ import ( "io" "net/http" "one-api/dto" - relaychannel "one-api/relay/channel" + "one-api/relay/channel" relaycommon "one-api/relay/common" ) @@ -26,7 +26,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { - relaychannel.SetupApiRequestHeader(info, c, req) + channel.SetupApiRequestHeader(info, c, req) token := getZhipuToken(info.ApiKey) req.Header.Set("Authorization", token) return nil @@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return relaychannel.DoApiRequest(a, c, info, requestBody) + return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 62302a0..8051cec 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -56,9 +56,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.BaseUrl == "" { info.BaseUrl = common.ChannelBaseURLs[channelType] } - //if info.ChannelType == common.ChannelTypeAzure { - // info.ApiVersion = GetAzureAPIVersion(c) - //} + if info.ChannelType == common.ChannelTypeAzure { + info.ApiVersion = GetAzureAPIVersion(c) + } return info } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 0ab38cb..8e75d24 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -66,3 +66,12 @@ func GetAPIVersion(c *gin.Context) string { } return apiVersion } + +func GetAzureAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 1a62fff..3a0841f 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -10,9 +10,10 @@ import ( "io" "net/http" "one-api/common" - "one-api/controller" "one-api/dto" "one-api/model" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" "one-api/service" "strings" "time" @@ -27,7 +28,7 @@ var availableVoices = []string{ "shimmer", } -func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode { +func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") @@ -35,14 +36,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith group := c.GetString("group") startTime := time.Now() - var audioRequest AudioRequest + var audioRequest dto.TextToSpeechRequest if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err := common.UnmarshalBodyReusable(c, &audioRequest) if err != nil { return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } else { - audioRequest = AudioRequest{ + audioRequest = dto.TextToSpeechRequest{ Model: "whisper-1", } } @@ -109,10 +110,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith baseURL = c.GetString("base_url") } - fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType) - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) + if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := common.GetAPIVersion(c) + apiVersion := relaycommon.GetAzureAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) } @@ -123,7 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { + if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -136,7 +137,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := controller.httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } @@ -151,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith } if resp.StatusCode != http.StatusOK { - return common.relayErrorHandler(resp) + return relaycommon.RelayErrorHandler(resp) } var audioResponse dto.AudioResponse @@ -162,10 +163,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith quota := 0 var promptTokens = 0 if strings.HasPrefix(audioRequest.Model, "tts-1") { - quota = service.countAudioToken(audioRequest.Input, audioRequest.Model) + quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model) promptTokens = quota } else { - quota = service.countAudioToken(audioResponse.Text, audioRequest.Model) + quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model) } quota = int(float64(quota) * ratio) if ratio != 0 && quota <= 0 { diff --git a/relay/relay-image.go b/relay/relay-image.go index e717c3e..3065496 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -10,15 +10,16 @@ import ( "io" "net/http" "one-api/common" - "one-api/controller" "one-api/dto" "one-api/model" - "one-api/relay/common" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" "strings" "time" ) -func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") @@ -31,7 +32,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if consumeQuota { err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } @@ -46,29 +47,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) } if strings.Contains(imageRequest.Size, "×") { - return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) } // Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) } if imageRequest.N != 1 { - return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) } } // N should between 1 and 10 if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) } // map model name @@ -78,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[imageRequest.Model] != "" { imageRequest.Model = modelMap[imageRequest.Model] @@ -90,10 +91,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) + if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := common.GetAPIVersion(c) + apiVersion := relaycommon.GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) } @@ -101,7 +102,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { @@ -136,12 +137,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } token := c.Request.Header.Get("Authorization") @@ -154,25 +155,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := controller.httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } if resp.StatusCode != http.StatusOK { - return relayErrorHandler(resp) + return relaycommon.RelayErrorHandler(resp) } - var textResponse ImageResponse + var textResponse dto.ImageResponse defer func(ctx context.Context) { useTimeSeconds := time.Now().Unix() - startTime.Unix() if consumeQuota { @@ -202,15 +203,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) @@ -223,11 +224,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } return nil } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 36114bf..b2b9926 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -9,8 +9,10 @@ import ( "log" "net/http" "one-api/common" - "one-api/controller" + "one-api/dto" "one-api/model" + relayconstant "one-api/relay/constant" + "one-api/service" "strconv" "strings" "time" @@ -105,11 +107,11 @@ func RelayMidjourneyImage(c *gin.Context) { return } -func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { +func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { var midjRequest Midjourney err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "bind_request_body_failed", Properties: nil, @@ -118,7 +120,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { } midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) if midjourneyTask == nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "midjourney_task_not_found", Properties: nil, @@ -136,7 +138,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { midjourneyTask.FailReason = midjRequest.FailReason err = midjourneyTask.Update() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "update_midjourney_task_failed", } @@ -168,16 +170,16 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { +func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse { userId := c.GetInt("id") var err error var respBody []byte switch relayMode { - case RelayModeMidjourneyTaskFetch: + case relayconstant.RelayModeMidjourneyTaskFetch: taskId := c.Param("id") originTask := model.GetByMJId(userId, taskId) if originTask == nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "task_no_found", } @@ -185,18 +187,18 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { midjourneyTask := getMidjourneyTaskModel(c, originTask) respBody, err = json.Marshal(midjourneyTask) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "unmarshal_response_body_failed", } } - case RelayModeMidjourneyTaskFetchByCondition: + case relayconstant.RelayModeMidjourneyTaskFetchByCondition: var condition = struct { IDs []string `json:"ids"` }{} err = c.BindJSON(&condition) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "do_request_failed", } @@ -214,7 +216,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { } respBody, err = json.Marshal(tasks) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "unmarshal_response_body_failed", } @@ -225,7 +227,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "copy_response_body_failed", } @@ -245,7 +247,7 @@ const ( MJSubmitActionUpscale = "UPSCALE" // 放大 ) -func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { +func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { imageModel := "midjourney" tokenId := c.GetInt("token_id") @@ -254,60 +256,60 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") channelId := c.GetInt("channel_id") - var midjRequest MidjourneyRequest + var midjRequest dto.MidjourneyRequest if consumeQuota { err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "bind_request_body_failed", } } } - if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "prompt_is_required", } } midjRequest.Action = "IMAGINE" - } else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = "DESCRIBE" - } else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = "BLEND" } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" - if relayMode == RelayModeMidjourneyChange { + if relayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "taskId_is_required", } } else if midjRequest.Action == "" { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "action_is_required", } } else if midjRequest.Index == 0 { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "index_can_only_be_1_2_3_4", } } //action = midjRequest.Action mjId = midjRequest.TaskId - } else if relayMode == RelayModeMidjourneySimpleChange { + } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "content_is_required", } } params := convertSimpleChangeParams(midjRequest.Content) if params == nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "content_parse_failed", } @@ -318,25 +320,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { originTask := model.GetByMJId(userId, mjId) if originTask == nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "task_no_found", } } else if originTask.Action == "UPSCALE" { //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "upscale_task_can_not_be_change", } } else if originTask.Status != "SUCCESS" { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "task_status_is_not_success", } } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 channel, err := model.GetChannelById(originTask.ChannelId, false) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "channel_not_found", } @@ -356,7 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "unmarshal_model_mapping_failed", } @@ -383,7 +385,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { if isModelMapped { jsonStr, err := json.Marshal(midjRequest) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "marshal_text_request_failed", } @@ -407,7 +409,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { ratio := modelPrice * groupRatio userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: err.Error(), } @@ -415,7 +417,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { quota := int(ratio * common.QuotaPerUnit) if consumeQuota && userQuota-quota < 0 { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "quota_not_enough", } @@ -423,7 +425,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "create_request_failed", } @@ -442,9 +444,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { log.Printf("request header: %s", req.Header) log.Printf("request body: %s", midjRequest.Prompt) - resp, err := controller.httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "do_request_failed", } @@ -452,19 +454,19 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { err = req.Body.Close() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "close_request_body_failed", } } err = c.Request.Body.Close() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "close_request_body_failed", } } - var midjResponse MidjourneyResponse + var midjResponse dto.MidjourneyResponse defer func(ctx context.Context) { if consumeQuota { @@ -493,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "read_response_body_failed", } } err = resp.Body.Close() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "close_response_body_failed", } @@ -510,13 +512,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { log.Printf("responseBody: %s", string(responseBody)) log.Printf("midjResponse: %v", midjResponse) if resp.StatusCode != 200 { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode), } } if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "unmarshal_response_body_failed", } @@ -579,7 +581,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { err = midjourneyTask.Insert() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "insert_midjourney_task_failed", } @@ -600,14 +602,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "copy_response_body_failed", } } err = resp.Body.Close() if err != nil { - return &MidjourneyResponse{ + return &dto.MidjourneyResponse{ Code: 4, Description: "close_response_body_failed", } diff --git a/relay/relay-text.go b/relay/relay-text.go index 4c29b04..d864c84 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/dto" "one-api/model" - relaychannel "one-api/relay/channel" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" @@ -119,7 +118,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return openaiErr } - adaptor := relaychannel.GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go new file mode 100644 index 0000000..d1e613e --- /dev/null +++ b/relay/relay_adaptor.go @@ -0,0 +1,41 @@ +package relay + +import ( + "one-api/relay/channel" + "one-api/relay/channel/ali" + "one-api/relay/channel/baidu" + "one-api/relay/channel/claude" + "one-api/relay/channel/gemini" + "one-api/relay/channel/openai" + "one-api/relay/channel/palm" + "one-api/relay/channel/tencent" + "one-api/relay/channel/xunfei" + "one-api/relay/channel/zhipu" + "one-api/relay/constant" +) + +func GetAdaptor(apiType int) channel.Adaptor { + switch apiType { + //case constant.APITypeAIProxyLibrary: + // return &aiproxy.Adaptor{} + case constant.APITypeAli: + return &ali.Adaptor{} + case constant.APITypeAnthropic: + return &claude.Adaptor{} + case constant.APITypeBaidu: + return &baidu.Adaptor{} + case constant.APITypeGemini: + return &gemini.Adaptor{} + case constant.APITypeOpenAI: + return &openai.Adaptor{} + case constant.APITypePaLM: + return &palm.Adaptor{} + case constant.APITypeTencent: + return &tencent.Adaptor{} + case constant.APITypeXunfei: + return &xunfei.Adaptor{} + case constant.APITypeZhipu: + return &zhipu.Adaptor{} + } + return nil +} diff --git a/service/token_counter.go b/service/token_counter.go index 179eccd..de476f1 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -201,7 +201,7 @@ func CountTokenInput(input any, model string) int { return 0 } -func countAudioToken(text string, model string) int { +func CountAudioToken(text string, model string) int { if strings.HasPrefix(model, "tts") { return utf8.RuneCountInString(text) } else {