mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: 初步重构完成
This commit is contained in:
		@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
		expiredTime = 0
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
		openAIError := dto.OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
		}
 | 
			
		||||
@@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
 | 
			
		||||
		quota, err = model.GetUserUsedQuota(userId)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
		openAIError := dto.OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "new_api_error",
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
@@ -39,7 +39,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 | 
			
		||||
	c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
	meta := relaycommon.GenRelayInfo(c)
 | 
			
		||||
	apiType := constant.ChannelType2APIType(channel.Type)
 | 
			
		||||
	adaptor := relaychannel.GetAdaptor(apiType)
 | 
			
		||||
	adaptor := relay.GetAdaptor(apiType)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,9 +10,9 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/controller/relay"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relay2 "one-api/relay"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -223,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
 | 
			
		||||
			req = req.WithContext(ctx)
 | 
			
		||||
			req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 | 
			
		||||
			resp, err := relay.httpClient.Do(req)
 | 
			
		||||
			resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 | 
			
		||||
				continue
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/models/list
 | 
			
		||||
@@ -639,7 +640,7 @@ func RetrieveModel(c *gin.Context) {
 | 
			
		||||
	if model, ok := openAIModelsMap[modelId]; ok {
 | 
			
		||||
		c.JSON(200, model)
 | 
			
		||||
	} else {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
		openAIError := dto.OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
			
		||||
			Type:    "invalid_request_error",
 | 
			
		||||
			Param:   "model",
 | 
			
		||||
 
 | 
			
		||||
@@ -26,7 +26,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
	case relayconstant.RelayModeAudioTranslation:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case relayconstant.RelayModeAudioTranscription:
 | 
			
		||||
		err = relay.RelayAudioHelper(c, relayMode)
 | 
			
		||||
		err = relay.AudioHelper(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = relay.TextHelper(c)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										13
									
								
								dto/audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								dto/audio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type TextToSpeechRequest struct {
 | 
			
		||||
	Model          string  `json:"model" binding:"required"`
 | 
			
		||||
	Input          string  `json:"input" binding:"required"`
 | 
			
		||||
	Voice          string  `json:"voice" binding:"required"`
 | 
			
		||||
	Speed          float64 `json:"speed"`
 | 
			
		||||
	ResponseFormat string  `json:"response_format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioResponse struct {
 | 
			
		||||
	Text string `json:"text"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								dto/dalle.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								dto/dalle.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type ImageRequest struct {
 | 
			
		||||
	Model          string `json:"model"`
 | 
			
		||||
	Prompt         string `json:"prompt" binding:"required"`
 | 
			
		||||
	N              int    `json:"n,omitempty"`
 | 
			
		||||
	Size           string `json:"size,omitempty"`
 | 
			
		||||
	Quality        string `json:"quality,omitempty"`
 | 
			
		||||
	ResponseFormat string `json:"response_format,omitempty"`
 | 
			
		||||
	Style          string `json:"style,omitempty"`
 | 
			
		||||
	User           string `json:"user,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		Url     string `json:"url"`
 | 
			
		||||
		B64Json string `json:"b64_json"`
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								dto/midjourney.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								dto/midjourney.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type MidjourneyRequest struct {
 | 
			
		||||
	Prompt      string   `json:"prompt"`
 | 
			
		||||
	NotifyHook  string   `json:"notifyHook"`
 | 
			
		||||
	Action      string   `json:"action"`
 | 
			
		||||
	Index       int      `json:"index"`
 | 
			
		||||
	State       string   `json:"state"`
 | 
			
		||||
	TaskId      string   `json:"taskId"`
 | 
			
		||||
	Base64Array []string `json:"base64Array"`
 | 
			
		||||
	Content     string   `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidjourneyResponse struct {
 | 
			
		||||
	Code        int         `json:"code"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	Properties  interface{} `json:"properties"`
 | 
			
		||||
	Result      string      `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
@@ -33,14 +33,6 @@ type OpenAIEmbeddingResponse struct {
 | 
			
		||||
	Usage  `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		Url     string `json:"url"`
 | 
			
		||||
		B64Json string `json:"b64_json"`
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatCompletionsStreamResponseChoice struct {
 | 
			
		||||
	Delta struct {
 | 
			
		||||
		Content string `json:"content"`
 | 
			
		||||
@@ -66,21 +58,3 @@ type CompletionsStreamResponse struct {
 | 
			
		||||
		FinishReason string `json:"finish_reason"`
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidjourneyRequest struct {
 | 
			
		||||
	Prompt      string   `json:"prompt"`
 | 
			
		||||
	NotifyHook  string   `json:"notifyHook"`
 | 
			
		||||
	Action      string   `json:"action"`
 | 
			
		||||
	Index       int      `json:"index"`
 | 
			
		||||
	State       string   `json:"state"`
 | 
			
		||||
	TaskId      string   `json:"taskId"`
 | 
			
		||||
	Base64Array []string `json:"base64Array"`
 | 
			
		||||
	Content     string   `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidjourneyResponse struct {
 | 
			
		||||
	Code        int         `json:"code"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	Properties  interface{} `json:"properties"`
 | 
			
		||||
	Result      string      `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										4
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								main.go
									
									
									
									
									
								
							@@ -12,8 +12,8 @@ import (
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"one-api/relay/common"
 | 
			
		||||
	"one-api/router"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
@@ -106,7 +106,7 @@ func main() {
 | 
			
		||||
		common.SysLog("pprof enabled")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	common.InitTokenEncoders()
 | 
			
		||||
	service.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.New()
 | 
			
		||||
 
 | 
			
		||||
@@ -5,17 +5,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/relay/channel/ali"
 | 
			
		||||
	"one-api/relay/channel/baidu"
 | 
			
		||||
	"one-api/relay/channel/claude"
 | 
			
		||||
	"one-api/relay/channel/gemini"
 | 
			
		||||
	"one-api/relay/channel/openai"
 | 
			
		||||
	"one-api/relay/channel/palm"
 | 
			
		||||
	"one-api/relay/channel/tencent"
 | 
			
		||||
	"one-api/relay/channel/xunfei"
 | 
			
		||||
	"one-api/relay/channel/zhipu"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor interface {
 | 
			
		||||
@@ -29,29 +19,3 @@ type Adaptor interface {
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
	GetChannelName() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAdaptor(apiType int) Adaptor {
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	//case constant.APITypeAIProxyLibrary:
 | 
			
		||||
	//	return &aiproxy.Adaptor{}
 | 
			
		||||
	case constant.APITypeAli:
 | 
			
		||||
		return &ali.Adaptor{}
 | 
			
		||||
	case constant.APITypeAnthropic:
 | 
			
		||||
		return &claude.Adaptor{}
 | 
			
		||||
	case constant.APITypeBaidu:
 | 
			
		||||
		return &baidu.Adaptor{}
 | 
			
		||||
	case constant.APITypeGemini:
 | 
			
		||||
		return &gemini.Adaptor{}
 | 
			
		||||
	case constant.APITypeOpenAI:
 | 
			
		||||
		return &openai.Adaptor{}
 | 
			
		||||
	case constant.APITypePaLM:
 | 
			
		||||
		return &palm.Adaptor{}
 | 
			
		||||
	case constant.APITypeTencent:
 | 
			
		||||
		return &tencent.Adaptor{}
 | 
			
		||||
	case constant.APITypeXunfei:
 | 
			
		||||
		return &xunfei.Adaptor{}
 | 
			
		||||
	case constant.APITypeZhipu:
 | 
			
		||||
		return &zhipu.Adaptor{}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
@@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -6,11 +6,11 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
 | 
			
		||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
	if info.IsStream && c.Request.Header.Get("Accept") == "" {
 | 
			
		||||
@@ -18,7 +18,7 @@ func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *htt
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	fullRequestURL, err := a.GetRequestURL(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("get request url failed: %w", err)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
@@ -46,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-api-key", info.ApiKey)
 | 
			
		||||
	anthropicVersion := c.Request.Header.Get("anthropic-version")
 | 
			
		||||
	if anthropicVersion == "" {
 | 
			
		||||
@@ -42,7 +42,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
@@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,7 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -40,7 +40,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	if info.ChannelType == common.ChannelTypeAzure {
 | 
			
		||||
		req.Header.Set("api-key", info.ApiKey)
 | 
			
		||||
		return nil
 | 
			
		||||
@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
)
 | 
			
		||||
@@ -23,7 +23,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("x-goog-api-key", info.ApiKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -25,7 +25,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	req.Header.Set("Authorization", a.Sign)
 | 
			
		||||
	req.Header.Set("X-TC-Action", info.UpstreamModelName)
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -26,7 +26,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
 | 
			
		||||
	relaychannel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	channel.SetupApiRequestHeader(info, c, req)
 | 
			
		||||
	token := getZhipuToken(info.ApiKey)
 | 
			
		||||
	req.Header.Set("Authorization", token)
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return relaychannel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
 
 | 
			
		||||
@@ -56,9 +56,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 | 
			
		||||
	if info.BaseUrl == "" {
 | 
			
		||||
		info.BaseUrl = common.ChannelBaseURLs[channelType]
 | 
			
		||||
	}
 | 
			
		||||
	//if info.ChannelType == common.ChannelTypeAzure {
 | 
			
		||||
	//	info.ApiVersion = GetAzureAPIVersion(c)
 | 
			
		||||
	//}
 | 
			
		||||
	if info.ChannelType == common.ChannelTypeAzure {
 | 
			
		||||
		info.ApiVersion = GetAzureAPIVersion(c)
 | 
			
		||||
	}
 | 
			
		||||
	return info
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -66,3 +66,12 @@ func GetAPIVersion(c *gin.Context) string {
 | 
			
		||||
	}
 | 
			
		||||
	return apiVersion
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAzureAPIVersion(c *gin.Context) string {
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	}
 | 
			
		||||
	return apiVersion
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,9 +10,10 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -27,7 +28,7 @@ var availableVoices = []string{
 | 
			
		||||
	"shimmer",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
 | 
			
		||||
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
@@ -35,14 +36,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	startTime := time.Now()
 | 
			
		||||
 | 
			
		||||
	var audioRequest AudioRequest
 | 
			
		||||
	var audioRequest dto.TextToSpeechRequest
 | 
			
		||||
	if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &audioRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		audioRequest = AudioRequest{
 | 
			
		||||
		audioRequest = dto.TextToSpeechRequest{
 | 
			
		||||
			Model: "whisper-1",
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -109,10 +110,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
			
		||||
		apiVersion := common.GetAPIVersion(c)
 | 
			
		||||
		apiVersion := relaycommon.GetAzureAPIVersion(c)
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -123,7 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
	if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
@@ -136,7 +137,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
	resp, err := controller.httpClient.Do(req)
 | 
			
		||||
	resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
@@ -151,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return common.relayErrorHandler(resp)
 | 
			
		||||
		return relaycommon.RelayErrorHandler(resp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var audioResponse dto.AudioResponse
 | 
			
		||||
@@ -162,10 +163,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
 | 
			
		||||
			quota := 0
 | 
			
		||||
			var promptTokens = 0
 | 
			
		||||
			if strings.HasPrefix(audioRequest.Model, "tts-1") {
 | 
			
		||||
				quota = service.countAudioToken(audioRequest.Input, audioRequest.Model)
 | 
			
		||||
				quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 | 
			
		||||
				promptTokens = quota
 | 
			
		||||
			} else {
 | 
			
		||||
				quota = service.countAudioToken(audioResponse.Text, audioRequest.Model)
 | 
			
		||||
				quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 | 
			
		||||
			}
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
			if ratio != 0 && quota <= 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -10,15 +10,16 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"one-api/relay/common"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
@@ -31,7 +32,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -46,29 +47,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	// Prompt validation
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
			
		||||
		return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(imageRequest.Size, "×") {
 | 
			
		||||
		return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// Not "256x256", "512x512", or "1024x1024"
 | 
			
		||||
	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
 | 
			
		||||
		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
 | 
			
		||||
			return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	} else if imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
 | 
			
		||||
			return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
			return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		if imageRequest.N != 1 {
 | 
			
		||||
			return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
			return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// N should between 1 and 10
 | 
			
		||||
	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
 | 
			
		||||
		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
		return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
@@ -78,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[imageRequest.Model] != "" {
 | 
			
		||||
			imageRequest.Model = modelMap[imageRequest.Model]
 | 
			
		||||
@@ -90,10 +91,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
 | 
			
		||||
	fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 | 
			
		||||
		apiVersion := common.GetAPIVersion(c)
 | 
			
		||||
		apiVersion := relaycommon.GetAPIVersion(c)
 | 
			
		||||
		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
@@ -101,7 +102,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 | 
			
		||||
		jsonStr, err := json.Marshal(imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -136,12 +137,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
 | 
			
		||||
 | 
			
		||||
	if consumeQuota && userQuota-quota < 0 {
 | 
			
		||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
		return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := c.Request.Header.Get("Authorization")
 | 
			
		||||
@@ -154,25 +155,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
	resp, err := controller.httpClient.Do(req)
 | 
			
		||||
	resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return relayErrorHandler(resp)
 | 
			
		||||
		return relaycommon.RelayErrorHandler(resp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var textResponse ImageResponse
 | 
			
		||||
	var textResponse dto.ImageResponse
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		useTimeSeconds := time.Now().Unix() - startTime.Unix()
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
@@ -202,15 +203,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
@@ -223,11 +224,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,8 +9,10 @@ import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -105,11 +107,11 @@ func RelayMidjourneyImage(c *gin.Context) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 | 
			
		||||
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
 | 
			
		||||
	var midjRequest Midjourney
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &midjRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "bind_request_body_failed",
 | 
			
		||||
			Properties:  nil,
 | 
			
		||||
@@ -118,7 +120,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 | 
			
		||||
	}
 | 
			
		||||
	midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
 | 
			
		||||
	if midjourneyTask == nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "midjourney_task_not_found",
 | 
			
		||||
			Properties:  nil,
 | 
			
		||||
@@ -136,7 +138,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
 | 
			
		||||
	midjourneyTask.FailReason = midjRequest.FailReason
 | 
			
		||||
	err = midjourneyTask.Update()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "update_midjourney_task_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -168,16 +170,16 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	var err error
 | 
			
		||||
	var respBody []byte
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case RelayModeMidjourneyTaskFetch:
 | 
			
		||||
	case relayconstant.RelayModeMidjourneyTaskFetch:
 | 
			
		||||
		taskId := c.Param("id")
 | 
			
		||||
		originTask := model.GetByMJId(userId, taskId)
 | 
			
		||||
		if originTask == nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "task_no_found",
 | 
			
		||||
			}
 | 
			
		||||
@@ -185,18 +187,18 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
		midjourneyTask := getMidjourneyTaskModel(c, originTask)
 | 
			
		||||
		respBody, err = json.Marshal(midjourneyTask)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "unmarshal_response_body_failed",
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case RelayModeMidjourneyTaskFetchByCondition:
 | 
			
		||||
	case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
 | 
			
		||||
		var condition = struct {
 | 
			
		||||
			IDs []string `json:"ids"`
 | 
			
		||||
		}{}
 | 
			
		||||
		err = c.BindJSON(&condition)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "do_request_failed",
 | 
			
		||||
			}
 | 
			
		||||
@@ -214,7 +216,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
		}
 | 
			
		||||
		respBody, err = json.Marshal(tasks)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "unmarshal_response_body_failed",
 | 
			
		||||
			}
 | 
			
		||||
@@ -225,7 +227,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "copy_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -245,7 +247,7 @@ const (
 | 
			
		||||
	MJSubmitActionUpscale  = "UPSCALE" // 放大
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 | 
			
		||||
	imageModel := "midjourney"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
@@ -254,60 +256,60 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	var midjRequest MidjourneyRequest
 | 
			
		||||
	var midjRequest dto.MidjourneyRequest
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &midjRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "bind_request_body_failed",
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
 | 
			
		||||
	if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
 | 
			
		||||
		if midjRequest.Prompt == "" {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "prompt_is_required",
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		midjRequest.Action = "IMAGINE"
 | 
			
		||||
	} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
 | 
			
		||||
	} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
 | 
			
		||||
		midjRequest.Action = "DESCRIBE"
 | 
			
		||||
	} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
 | 
			
		||||
	} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
 | 
			
		||||
		midjRequest.Action = "BLEND"
 | 
			
		||||
	} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
 | 
			
		||||
		mjId := ""
 | 
			
		||||
		if relayMode == RelayModeMidjourneyChange {
 | 
			
		||||
		if relayMode == relayconstant.RelayModeMidjourneyChange {
 | 
			
		||||
			if midjRequest.TaskId == "" {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "taskId_is_required",
 | 
			
		||||
				}
 | 
			
		||||
			} else if midjRequest.Action == "" {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "action_is_required",
 | 
			
		||||
				}
 | 
			
		||||
			} else if midjRequest.Index == 0 {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "index_can_only_be_1_2_3_4",
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			//action = midjRequest.Action
 | 
			
		||||
			mjId = midjRequest.TaskId
 | 
			
		||||
		} else if relayMode == RelayModeMidjourneySimpleChange {
 | 
			
		||||
		} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
 | 
			
		||||
			if midjRequest.Content == "" {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "content_is_required",
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			params := convertSimpleChangeParams(midjRequest.Content)
 | 
			
		||||
			if params == nil {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "content_parse_failed",
 | 
			
		||||
				}
 | 
			
		||||
@@ -318,25 +320,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
		originTask := model.GetByMJId(userId, mjId)
 | 
			
		||||
		if originTask == nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "task_no_found",
 | 
			
		||||
			}
 | 
			
		||||
		} else if originTask.Action == "UPSCALE" {
 | 
			
		||||
			//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "upscale_task_can_not_be_change",
 | 
			
		||||
			}
 | 
			
		||||
		} else if originTask.Status != "SUCCESS" {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "task_status_is_not_success",
 | 
			
		||||
			}
 | 
			
		||||
		} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
 | 
			
		||||
			channel, err := model.GetChannelById(originTask.ChannelId, false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return &MidjourneyResponse{
 | 
			
		||||
				return &dto.MidjourneyResponse{
 | 
			
		||||
					Code:        4,
 | 
			
		||||
					Description: "channel_not_found",
 | 
			
		||||
				}
 | 
			
		||||
@@ -356,7 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "unmarshal_model_mapping_failed",
 | 
			
		||||
			}
 | 
			
		||||
@@ -383,7 +385,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	if isModelMapped {
 | 
			
		||||
		jsonStr, err := json.Marshal(midjRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &MidjourneyResponse{
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "marshal_text_request_failed",
 | 
			
		||||
			}
 | 
			
		||||
@@ -407,7 +409,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	ratio := modelPrice * groupRatio
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: err.Error(),
 | 
			
		||||
		}
 | 
			
		||||
@@ -415,7 +417,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	quota := int(ratio * common.QuotaPerUnit)
 | 
			
		||||
 | 
			
		||||
	if consumeQuota && userQuota-quota < 0 {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "quota_not_enough",
 | 
			
		||||
		}
 | 
			
		||||
@@ -423,7 +425,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "create_request_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -442,9 +444,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	log.Printf("request header: %s", req.Header)
 | 
			
		||||
	log.Printf("request body: %s", midjRequest.Prompt)
 | 
			
		||||
 | 
			
		||||
	resp, err := controller.httpClient.Do(req)
 | 
			
		||||
	resp, err := service.GetHttpClient().Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "do_request_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -452,19 +454,19 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "close_request_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "close_request_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var midjResponse MidjourneyResponse
 | 
			
		||||
	var midjResponse dto.MidjourneyResponse
 | 
			
		||||
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
@@ -493,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "read_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "close_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -510,13 +512,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
	log.Printf("responseBody: %s", string(responseBody))
 | 
			
		||||
	log.Printf("midjResponse: %v", midjResponse)
 | 
			
		||||
	if resp.StatusCode != 200 {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "unmarshal_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -579,7 +581,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
	err = midjourneyTask.Insert()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "insert_midjourney_task_failed",
 | 
			
		||||
		}
 | 
			
		||||
@@ -600,14 +602,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "copy_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &MidjourneyResponse{
 | 
			
		||||
		return &dto.MidjourneyResponse{
 | 
			
		||||
			Code:        4,
 | 
			
		||||
			Description: "close_response_body_failed",
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,6 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relaychannel "one-api/relay/channel"
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
@@ -119,7 +118,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		return openaiErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
 | 
			
		||||
	adaptor := GetAdaptor(relayInfo.ApiType)
 | 
			
		||||
	if adaptor == nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										41
									
								
								relay/relay_adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								relay/relay_adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package relay
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/relay/channel"
 | 
			
		||||
	"one-api/relay/channel/ali"
 | 
			
		||||
	"one-api/relay/channel/baidu"
 | 
			
		||||
	"one-api/relay/channel/claude"
 | 
			
		||||
	"one-api/relay/channel/gemini"
 | 
			
		||||
	"one-api/relay/channel/openai"
 | 
			
		||||
	"one-api/relay/channel/palm"
 | 
			
		||||
	"one-api/relay/channel/tencent"
 | 
			
		||||
	"one-api/relay/channel/xunfei"
 | 
			
		||||
	"one-api/relay/channel/zhipu"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAdaptor(apiType int) channel.Adaptor {
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	//case constant.APITypeAIProxyLibrary:
 | 
			
		||||
	//	return &aiproxy.Adaptor{}
 | 
			
		||||
	case constant.APITypeAli:
 | 
			
		||||
		return &ali.Adaptor{}
 | 
			
		||||
	case constant.APITypeAnthropic:
 | 
			
		||||
		return &claude.Adaptor{}
 | 
			
		||||
	case constant.APITypeBaidu:
 | 
			
		||||
		return &baidu.Adaptor{}
 | 
			
		||||
	case constant.APITypeGemini:
 | 
			
		||||
		return &gemini.Adaptor{}
 | 
			
		||||
	case constant.APITypeOpenAI:
 | 
			
		||||
		return &openai.Adaptor{}
 | 
			
		||||
	case constant.APITypePaLM:
 | 
			
		||||
		return &palm.Adaptor{}
 | 
			
		||||
	case constant.APITypeTencent:
 | 
			
		||||
		return &tencent.Adaptor{}
 | 
			
		||||
	case constant.APITypeXunfei:
 | 
			
		||||
		return &xunfei.Adaptor{}
 | 
			
		||||
	case constant.APITypeZhipu:
 | 
			
		||||
		return &zhipu.Adaptor{}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -201,7 +201,7 @@ func CountTokenInput(input any, model string) int {
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countAudioToken(text string, model string) int {
 | 
			
		||||
func CountAudioToken(text string, model string) int {
 | 
			
		||||
	if strings.HasPrefix(model, "tts") {
 | 
			
		||||
		return utf8.RuneCountInString(text)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user