feat: 初步重构完成

This commit is contained in:
1808837298@qq.com 2024-02-29 16:21:25 +08:00
parent 5b18cd6b0a
commit 6013219f5b
30 changed files with 240 additions and 195 deletions

View File

@ -3,6 +3,7 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/model" "one-api/model"
) )
@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
expiredTime = 0 expiredTime = 0
} }
if err != nil { if err != nil {
openAIError := OpenAIError{ openAIError := dto.OpenAIError{
Message: err.Error(), Message: err.Error(),
Type: "upstream_error", Type: "upstream_error",
} }
@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId) quota, err = model.GetUserUsedQuota(userId)
} }
if err != nil { if err != nil {
openAIError := OpenAIError{ openAIError := dto.OpenAIError{
Message: err.Error(), Message: err.Error(),
Type: "new_api_error", Type: "new_api_error",
} }

View File

@ -12,7 +12,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaychannel "one-api/relay/channel" "one-api/relay"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
@ -39,7 +39,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
meta := relaycommon.GenRelayInfo(c) meta := relaycommon.GenRelayInfo(c)
apiType := constant.ChannelType2APIType(channel.Type) apiType := constant.ChannelType2APIType(channel.Type)
adaptor := relaychannel.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }

View File

@ -10,9 +10,9 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller/relay"
"one-api/model" "one-api/model"
relay2 "one-api/relay" relay2 "one-api/relay"
"one-api/service"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -223,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
req = req.WithContext(ctx) req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key) req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := relay.httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue continue

View File

@ -3,6 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/dto"
) )
// https://platform.openai.com/docs/api-reference/models/list // https://platform.openai.com/docs/api-reference/models/list
@ -639,7 +640,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model) c.JSON(200, model)
} else { } else {
openAIError := OpenAIError{ openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId), Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error", Type: "invalid_request_error",
Param: "model", Param: "model",

View File

@ -26,7 +26,7 @@ func Relay(c *gin.Context) {
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranscription: case relayconstant.RelayModeAudioTranscription:
err = relay.RelayAudioHelper(c, relayMode) err = relay.AudioHelper(c, relayMode)
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }

13
dto/audio.go Normal file
View File

@ -0,0 +1,13 @@
package dto
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type AudioResponse struct {
Text string `json:"text"`
}

20
dto/dalle.go Normal file
View File

@ -0,0 +1,20 @@
package dto
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}

19
dto/midjourney.go Normal file
View File

@ -0,0 +1,19 @@
package dto
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}

View File

@ -33,14 +33,6 @@ type OpenAIEmbeddingResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
} }
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}
type ChatCompletionsStreamResponseChoice struct { type ChatCompletionsStreamResponseChoice struct {
Delta struct { Delta struct {
Content string `json:"content"` Content string `json:"content"`
@ -66,21 +58,3 @@ type CompletionsStreamResponse struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` } `json:"choices"`
} }
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}

View File

@ -12,8 +12,8 @@ import (
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay/common"
"one-api/router" "one-api/router"
"one-api/service"
"os" "os"
"strconv" "strconv"
@ -106,7 +106,7 @@ func main() {
common.SysLog("pprof enabled") common.SysLog("pprof enabled")
} }
common.InitTokenEncoders() service.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.New()

View File

@ -5,17 +5,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel/ali"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant"
) )
type Adaptor interface { type Adaptor interface {
@ -29,29 +19,3 @@ type Adaptor interface {
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
} }
func GetAdaptor(apiType int) Adaptor {
switch apiType {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case constant.APITypeAli:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
return &claude.Adaptor{}
case constant.APITypeBaidu:
return &baidu.Adaptor{}
case constant.APITypeGemini:
return &gemini.Adaptor{}
case constant.APITypeOpenAI:
return &openai.Adaptor{}
case constant.APITypePaLM:
return &palm.Adaptor{}
case constant.APITypeTencent:
return &tencent.Adaptor{}
case constant.APITypeXunfei:
return &xunfei.Adaptor{}
case constant.APITypeZhipu:
return &zhipu.Adaptor{}
}
return nil
}

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
) )
@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream { if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Header.Set("X-DashScope-SSE", "enable")
@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -6,11 +6,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
relaycommon "one-api/relay/common" "one-api/relay/common"
"one-api/service" "one-api/service"
) )
func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" { if info.IsStream && c.Request.Header.Get("Accept") == "" {
@ -18,7 +18,7 @@ func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *htt
} }
} }
func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info) fullRequestURL, err := a.GetRequestURL(info)
if err != nil { if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)

View File

@ -6,7 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
) )
@ -46,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
) )
@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey) req.Header.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version") anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" { if anthropicVersion == "" {
@ -42,7 +42,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
) )
@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Header.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -8,7 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
@ -40,7 +40,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey) req.Header.Set("api-key", info.ApiKey)
return nil return nil
@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
) )
@ -23,7 +23,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Header.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
@ -25,7 +25,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign) req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", info.UpstreamModelName) req.Header.Set("X-TC-Action", info.UpstreamModelName)
return nil return nil
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -6,7 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
return nil return nil
} }

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
relaychannel "one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
) )
@ -26,7 +26,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
relaychannel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey) token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token) req.Header.Set("Authorization", token)
return nil return nil
@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return relaychannel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {

View File

@ -56,9 +56,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.BaseUrl == "" { if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType] info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
//if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
// info.ApiVersion = GetAzureAPIVersion(c) info.ApiVersion = GetAzureAPIVersion(c)
//} }
return info return info
} }

View File

@ -66,3 +66,12 @@ func GetAPIVersion(c *gin.Context) string {
} }
return apiVersion return apiVersion
} }
func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}

View File

@ -10,9 +10,10 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings" "strings"
"time" "time"
@ -27,7 +28,7 @@ var availableVoices = []string{
"shimmer", "shimmer",
} }
func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode { func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@ -35,14 +36,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
group := c.GetString("group") group := c.GetString("group")
startTime := time.Now() startTime := time.Now()
var audioRequest AudioRequest var audioRequest dto.TextToSpeechRequest
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest) err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
} }
} else { } else {
audioRequest = AudioRequest{ audioRequest = dto.TextToSpeechRequest{
Model: "whisper-1", Model: "whisper-1",
} }
} }
@ -109,10 +110,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := common.GetAPIVersion(c) apiVersion := relaycommon.GetAzureAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
} }
@ -123,7 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@ -136,7 +137,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := controller.httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
@ -151,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return common.relayErrorHandler(resp) return relaycommon.RelayErrorHandler(resp)
} }
var audioResponse dto.AudioResponse var audioResponse dto.AudioResponse
@ -162,10 +163,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
quota := 0 quota := 0
var promptTokens = 0 var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") { if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = service.countAudioToken(audioRequest.Input, audioRequest.Model) quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens = quota promptTokens = quota
} else { } else {
quota = service.countAudioToken(audioResponse.Text, audioRequest.Model) quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
} }
quota = int(float64(quota) * ratio) quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {

View File

@ -10,15 +10,16 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings" "strings"
"time" "time"
) )
func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@ -31,7 +32,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if consumeQuota { if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest) err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil { if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
} }
} }
@ -46,29 +47,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
// Prompt validation // Prompt validation
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
} }
if strings.Contains(imageRequest.Size, "×") { if strings.Contains(imageRequest.Size, "×") {
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
} }
// Not "256x256", "512x512", or "1024x1024" // Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
} }
} else if imageRequest.Model == "dall-e-3" { } else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
} }
if imageRequest.N != 1 { if imageRequest.N != 1 {
return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
} }
} }
// N should between 1 and 10 // N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
} }
// map model name // map model name
@ -78,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
modelMap := make(map[string]string) modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap) err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
} }
if modelMap[imageRequest.Model] != "" { if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model] imageRequest.Model = modelMap[imageRequest.Model]
@ -90,10 +91,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := common.GetAPIVersion(c) apiVersion := relaycommon.GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
} }
@ -101,7 +102,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest) jsonStr, err := json.Marshal(imageRequest)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} else { } else {
@ -136,12 +137,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
} }
token := c.Request.Header.Get("Authorization") token := c.Request.Header.Get("Authorization")
@ -154,25 +155,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := controller.httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
err = req.Body.Close() err = req.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
err = c.Request.Body.Close() err = c.Request.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp) return relaycommon.RelayErrorHandler(resp)
} }
var textResponse ImageResponse var textResponse dto.ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix() useTimeSeconds := time.Now().Unix() - startTime.Unix()
if consumeQuota { if consumeQuota {
@ -202,15 +203,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &textResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
@ -223,11 +224,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
return nil return nil
} }

View File

@ -9,8 +9,10 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller" "one-api/dto"
"one-api/model" "one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -105,11 +107,11 @@ func RelayMidjourneyImage(c *gin.Context) {
return return
} }
func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
var midjRequest Midjourney var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "bind_request_body_failed", Description: "bind_request_body_failed",
Properties: nil, Properties: nil,
@ -118,7 +120,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
} }
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil { if midjourneyTask == nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "midjourney_task_not_found", Description: "midjourney_task_not_found",
Properties: nil, Properties: nil,
@ -136,7 +138,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
midjourneyTask.FailReason = midjRequest.FailReason midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update() err = midjourneyTask.Update()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "update_midjourney_task_failed", Description: "update_midjourney_task_failed",
} }
@ -168,16 +170,16 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
return return
} }
func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id") userId := c.GetInt("id")
var err error var err error
var respBody []byte var respBody []byte
switch relayMode { switch relayMode {
case RelayModeMidjourneyTaskFetch: case relayconstant.RelayModeMidjourneyTaskFetch:
taskId := c.Param("id") taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId) originTask := model.GetByMJId(userId, taskId)
if originTask == nil { if originTask == nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "task_no_found", Description: "task_no_found",
} }
@ -185,18 +187,18 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
midjourneyTask := getMidjourneyTaskModel(c, originTask) midjourneyTask := getMidjourneyTaskModel(c, originTask)
respBody, err = json.Marshal(midjourneyTask) respBody, err = json.Marshal(midjourneyTask)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "unmarshal_response_body_failed", Description: "unmarshal_response_body_failed",
} }
} }
case RelayModeMidjourneyTaskFetchByCondition: case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
var condition = struct { var condition = struct {
IDs []string `json:"ids"` IDs []string `json:"ids"`
}{} }{}
err = c.BindJSON(&condition) err = c.BindJSON(&condition)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "do_request_failed", Description: "do_request_failed",
} }
@ -214,7 +216,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
} }
respBody, err = json.Marshal(tasks) respBody, err = json.Marshal(tasks)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "unmarshal_response_body_failed", Description: "unmarshal_response_body_failed",
} }
@ -225,7 +227,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "copy_response_body_failed", Description: "copy_response_body_failed",
} }
@ -245,7 +247,7 @@ const (
MJSubmitActionUpscale = "UPSCALE" // 放大 MJSubmitActionUpscale = "UPSCALE" // 放大
) )
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
imageModel := "midjourney" imageModel := "midjourney"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
@ -254,60 +256,60 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
var midjRequest MidjourneyRequest var midjRequest dto.MidjourneyRequest
if consumeQuota { if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "bind_request_body_failed", Description: "bind_request_body_failed",
} }
} }
} }
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" { if midjRequest.Prompt == "" {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "prompt_is_required", Description: "prompt_is_required",
} }
} }
midjRequest.Action = "IMAGINE" midjRequest.Action = "IMAGINE"
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = "DESCRIBE" midjRequest.Action = "DESCRIBE"
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND" midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果 } else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := "" mjId := ""
if relayMode == RelayModeMidjourneyChange { if relayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" { if midjRequest.TaskId == "" {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "taskId_is_required", Description: "taskId_is_required",
} }
} else if midjRequest.Action == "" { } else if midjRequest.Action == "" {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "action_is_required", Description: "action_is_required",
} }
} else if midjRequest.Index == 0 { } else if midjRequest.Index == 0 {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "index_can_only_be_1_2_3_4", Description: "index_can_only_be_1_2_3_4",
} }
} }
//action = midjRequest.Action //action = midjRequest.Action
mjId = midjRequest.TaskId mjId = midjRequest.TaskId
} else if relayMode == RelayModeMidjourneySimpleChange { } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" { if midjRequest.Content == "" {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "content_is_required", Description: "content_is_required",
} }
} }
params := convertSimpleChangeParams(midjRequest.Content) params := convertSimpleChangeParams(midjRequest.Content)
if params == nil { if params == nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "content_parse_failed", Description: "content_parse_failed",
} }
@ -318,25 +320,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
originTask := model.GetByMJId(userId, mjId) originTask := model.GetByMJId(userId, mjId)
if originTask == nil { if originTask == nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "task_no_found", Description: "task_no_found",
} }
} else if originTask.Action == "UPSCALE" { } else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest). //return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "upscale_task_can_not_be_change", Description: "upscale_task_can_not_be_change",
} }
} else if originTask.Status != "SUCCESS" { } else if originTask.Status != "SUCCESS" {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "task_status_is_not_success", Description: "task_status_is_not_success",
} }
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理 } else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false) channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "channel_not_found", Description: "channel_not_found",
} }
@ -356,7 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
err := json.Unmarshal([]byte(modelMapping), &modelMap) err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "unmarshal_model_mapping_failed", Description: "unmarshal_model_mapping_failed",
} }
@ -383,7 +385,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
if isModelMapped { if isModelMapped {
jsonStr, err := json.Marshal(midjRequest) jsonStr, err := json.Marshal(midjRequest)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "marshal_text_request_failed", Description: "marshal_text_request_failed",
} }
@ -407,7 +409,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: err.Error(), Description: err.Error(),
} }
@ -415,7 +417,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
quota := int(ratio * common.QuotaPerUnit) quota := int(ratio * common.QuotaPerUnit)
if consumeQuota && userQuota-quota < 0 { if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "quota_not_enough", Description: "quota_not_enough",
} }
@ -423,7 +425,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "create_request_failed", Description: "create_request_failed",
} }
@ -442,9 +444,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
log.Printf("request header: %s", req.Header) log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt) log.Printf("request body: %s", midjRequest.Prompt)
resp, err := controller.httpClient.Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "do_request_failed", Description: "do_request_failed",
} }
@ -452,19 +454,19 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
err = req.Body.Close() err = req.Body.Close()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "close_request_body_failed", Description: "close_request_body_failed",
} }
} }
err = c.Request.Body.Close() err = c.Request.Body.Close()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "close_request_body_failed", Description: "close_request_body_failed",
} }
} }
var midjResponse MidjourneyResponse var midjResponse dto.MidjourneyResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota {
@ -493,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "read_response_body_failed", Description: "read_response_body_failed",
} }
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "close_response_body_failed", Description: "close_response_body_failed",
} }
@ -510,13 +512,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
log.Printf("responseBody: %s", string(responseBody)) log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse) log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode), Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
} }
} }
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "unmarshal_response_body_failed", Description: "unmarshal_response_body_failed",
} }
@ -579,7 +581,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
err = midjourneyTask.Insert() err = midjourneyTask.Insert()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "insert_midjourney_task_failed", Description: "insert_midjourney_task_failed",
} }
@ -600,14 +602,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
_, err = io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "copy_response_body_failed", Description: "copy_response_body_failed",
} }
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return &MidjourneyResponse{ return &dto.MidjourneyResponse{
Code: 4, Code: 4,
Description: "close_response_body_failed", Description: "close_response_body_failed",
} }

View File

@ -11,7 +11,6 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
@ -119,7 +118,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr return openaiErr
} }
adaptor := relaychannel.GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
} }

41
relay/relay_adaptor.go Normal file
View File

@ -0,0 +1,41 @@
package relay
import (
"one-api/relay/channel"
"one-api/relay/channel/ali"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/constant"
)
func GetAdaptor(apiType int) channel.Adaptor {
switch apiType {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case constant.APITypeAli:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
return &claude.Adaptor{}
case constant.APITypeBaidu:
return &baidu.Adaptor{}
case constant.APITypeGemini:
return &gemini.Adaptor{}
case constant.APITypeOpenAI:
return &openai.Adaptor{}
case constant.APITypePaLM:
return &palm.Adaptor{}
case constant.APITypeTencent:
return &tencent.Adaptor{}
case constant.APITypeXunfei:
return &xunfei.Adaptor{}
case constant.APITypeZhipu:
return &zhipu.Adaptor{}
}
return nil
}

View File

@ -201,7 +201,7 @@ func CountTokenInput(input any, model string) int {
return 0 return 0
} }
func countAudioToken(text string, model string) int { func CountAudioToken(text string, model string) int {
if strings.HasPrefix(model, "tts") { if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text) return utf8.RuneCountInString(text)
} else { } else {