mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-18 00:16:37 +08:00
feat: 初步重构完成
This commit is contained in:
parent
5b18cd6b0a
commit
6013219f5b
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
expiredTime = 0
|
expiredTime = 0
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "upstream_error",
|
Type: "upstream_error",
|
||||||
}
|
}
|
||||||
@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
|
|||||||
quota, err = model.GetUserUsedQuota(userId)
|
quota, err = model.GetUserUsedQuota(userId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "new_api_error",
|
Type: "new_api_error",
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@ -39,7 +39,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
meta := relaycommon.GenRelayInfo(c)
|
meta := relaycommon.GenRelayInfo(c)
|
||||||
apiType := constant.ChannelType2APIType(channel.Type)
|
apiType := constant.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relaychannel.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||||
}
|
}
|
||||||
|
@ -10,9 +10,9 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller/relay"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relay2 "one-api/relay"
|
relay2 "one-api/relay"
|
||||||
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -223,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
resp, err := relay.httpClient.Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
|
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@ -639,7 +640,7 @@ func RetrieveModel(c *gin.Context) {
|
|||||||
if model, ok := openAIModelsMap[modelId]; ok {
|
if model, ok := openAIModelsMap[modelId]; ok {
|
||||||
c.JSON(200, model)
|
c.JSON(200, model)
|
||||||
} else {
|
} else {
|
||||||
openAIError := OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
Type: "invalid_request_error",
|
Type: "invalid_request_error",
|
||||||
Param: "model",
|
Param: "model",
|
||||||
|
@ -26,7 +26,7 @@ func Relay(c *gin.Context) {
|
|||||||
case relayconstant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.RelayAudioHelper(c, relayMode)
|
err = relay.AudioHelper(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
|
13
dto/audio.go
Normal file
13
dto/audio.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type TextToSpeechRequest struct {
|
||||||
|
Model string `json:"model" binding:"required"`
|
||||||
|
Input string `json:"input" binding:"required"`
|
||||||
|
Voice string `json:"voice" binding:"required"`
|
||||||
|
Speed float64 `json:"speed"`
|
||||||
|
ResponseFormat string `json:"response_format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioResponse struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
20
dto/dalle.go
Normal file
20
dto/dalle.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Created int `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
B64Json string `json:"b64_json"`
|
||||||
|
}
|
||||||
|
}
|
19
dto/midjourney.go
Normal file
19
dto/midjourney.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type MidjourneyRequest struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
NotifyHook string `json:"notifyHook"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
State string `json:"state"`
|
||||||
|
TaskId string `json:"taskId"`
|
||||||
|
Base64Array []string `json:"base64Array"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MidjourneyResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Properties interface{} `json:"properties"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
}
|
@ -33,14 +33,6 @@ type OpenAIEmbeddingResponse struct {
|
|||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Created int `json:"created"`
|
|
||||||
Data []struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
B64Json string `json:"b64_json"`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseChoice struct {
|
type ChatCompletionsStreamResponseChoice struct {
|
||||||
Delta struct {
|
Delta struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@ -66,21 +58,3 @@ type CompletionsStreamResponse struct {
|
|||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
} `json:"choices"`
|
} `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidjourneyRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
NotifyHook string `json:"notifyHook"`
|
|
||||||
Action string `json:"action"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
State string `json:"state"`
|
|
||||||
TaskId string `json:"taskId"`
|
|
||||||
Base64Array []string `json:"base64Array"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MidjourneyResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Properties interface{} `json:"properties"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
}
|
|
4
main.go
4
main.go
@ -12,8 +12,8 @@ import (
|
|||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/common"
|
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ func main() {
|
|||||||
common.SysLog("pprof enabled")
|
common.SysLog("pprof enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
common.InitTokenEncoders()
|
service.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
|
@ -5,17 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/ali"
|
|
||||||
"one-api/relay/channel/baidu"
|
|
||||||
"one-api/relay/channel/claude"
|
|
||||||
"one-api/relay/channel/gemini"
|
|
||||||
"one-api/relay/channel/openai"
|
|
||||||
"one-api/relay/channel/palm"
|
|
||||||
"one-api/relay/channel/tencent"
|
|
||||||
"one-api/relay/channel/xunfei"
|
|
||||||
"one-api/relay/channel/zhipu"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor interface {
|
type Adaptor interface {
|
||||||
@ -29,29 +19,3 @@ type Adaptor interface {
|
|||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAdaptor(apiType int) Adaptor {
|
|
||||||
switch apiType {
|
|
||||||
//case constant.APITypeAIProxyLibrary:
|
|
||||||
// return &aiproxy.Adaptor{}
|
|
||||||
case constant.APITypeAli:
|
|
||||||
return &ali.Adaptor{}
|
|
||||||
case constant.APITypeAnthropic:
|
|
||||||
return &claude.Adaptor{}
|
|
||||||
case constant.APITypeBaidu:
|
|
||||||
return &baidu.Adaptor{}
|
|
||||||
case constant.APITypeGemini:
|
|
||||||
return &gemini.Adaptor{}
|
|
||||||
case constant.APITypeOpenAI:
|
|
||||||
return &openai.Adaptor{}
|
|
||||||
case constant.APITypePaLM:
|
|
||||||
return &palm.Adaptor{}
|
|
||||||
case constant.APITypeTencent:
|
|
||||||
return &tencent.Adaptor{}
|
|
||||||
case constant.APITypeXunfei:
|
|
||||||
return &xunfei.Adaptor{}
|
|
||||||
case constant.APITypeZhipu:
|
|
||||||
return &zhipu.Adaptor{}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
)
|
)
|
||||||
@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Header.Set("X-DashScope-SSE", "enable")
|
||||||
@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -6,11 +6,11 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
relaycommon "one-api/relay/common"
|
"one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) {
|
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||||
@ -18,7 +18,7 @@ func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *htt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
fullRequestURL, err := a.GetRequestURL(info)
|
fullRequestURL, err := a.GetRequestURL(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
)
|
)
|
||||||
@ -46,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-api-key", info.ApiKey)
|
req.Header.Set("x-api-key", info.ApiKey)
|
||||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||||
if anthropicVersion == "" {
|
if anthropicVersion == "" {
|
||||||
@ -42,7 +42,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
@ -28,7 +28,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@ -40,7 +40,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == common.ChannelTypeAzure {
|
||||||
req.Header.Set("api-key", info.ApiKey)
|
req.Header.Set("api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
@ -61,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
)
|
)
|
||||||
@ -23,7 +23,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@ -25,7 +25,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
req.Header.Set("Authorization", a.Sign)
|
req.Header.Set("Authorization", a.Sign)
|
||||||
req.Header.Set("X-TC-Action", info.UpstreamModelName)
|
req.Header.Set("X-TC-Action", info.UpstreamModelName)
|
||||||
return nil
|
return nil
|
||||||
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaychannel "one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
relaychannel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
token := getZhipuToken(info.ApiKey)
|
token := getZhipuToken(info.ApiKey)
|
||||||
req.Header.Set("Authorization", token)
|
req.Header.Set("Authorization", token)
|
||||||
return nil
|
return nil
|
||||||
@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return relaychannel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
@ -56,9 +56,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
if info.BaseUrl == "" {
|
if info.BaseUrl == "" {
|
||||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||||
}
|
}
|
||||||
//if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == common.ChannelTypeAzure {
|
||||||
// info.ApiVersion = GetAzureAPIVersion(c)
|
info.ApiVersion = GetAzureAPIVersion(c)
|
||||||
//}
|
}
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,3 +66,12 @@ func GetAPIVersion(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetAzureAPIVersion(c *gin.Context) string {
|
||||||
|
query := c.Request.URL.Query()
|
||||||
|
apiVersion := query.Get("api-version")
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = c.GetString("api_version")
|
||||||
|
}
|
||||||
|
return apiVersion
|
||||||
|
}
|
||||||
|
@ -10,9 +10,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -27,7 +28,7 @@ var availableVoices = []string{
|
|||||||
"shimmer",
|
"shimmer",
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode {
|
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
@ -35,14 +36,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
var audioRequest AudioRequest
|
var audioRequest dto.TextToSpeechRequest
|
||||||
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
err := common.UnmarshalBodyReusable(c, &audioRequest)
|
err := common.UnmarshalBodyReusable(c, &audioRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
audioRequest = AudioRequest{
|
audioRequest = dto.TextToSpeechRequest{
|
||||||
Model: "whisper-1",
|
Model: "whisper-1",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -109,10 +110,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
|
|
||||||
fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType)
|
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||||
apiVersion := common.GetAPIVersion(c)
|
apiVersion := relaycommon.GetAzureAPIVersion(c)
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,7 +124,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
@ -136,7 +137,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
resp, err := controller.httpClient.Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -151,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return common.relayErrorHandler(resp)
|
return relaycommon.RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
var audioResponse dto.AudioResponse
|
var audioResponse dto.AudioResponse
|
||||||
@ -162,10 +163,10 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWith
|
|||||||
quota := 0
|
quota := 0
|
||||||
var promptTokens = 0
|
var promptTokens = 0
|
||||||
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
if strings.HasPrefix(audioRequest.Model, "tts-1") {
|
||||||
quota = service.countAudioToken(audioRequest.Input, audioRequest.Model)
|
quota = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
|
||||||
promptTokens = quota
|
promptTokens = quota
|
||||||
} else {
|
} else {
|
||||||
quota = service.countAudioToken(audioResponse.Text, audioRequest.Model)
|
quota = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
|
||||||
}
|
}
|
||||||
quota = int(float64(quota) * ratio)
|
quota = int(float64(quota) * ratio)
|
||||||
if ratio != 0 && quota <= 0 {
|
if ratio != 0 && quota <= 0 {
|
||||||
|
@ -10,15 +10,16 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
@ -31,7 +32,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,29 +47,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
// Prompt validation
|
// Prompt validation
|
||||||
if imageRequest.Prompt == "" {
|
if imageRequest.Prompt == "" {
|
||||||
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(imageRequest.Size, "×") {
|
if strings.Contains(imageRequest.Size, "×") {
|
||||||
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
// Not "256x256", "512x512", or "1024x1024"
|
// Not "256x256", "512x512", or "1024x1024"
|
||||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
} else if imageRequest.Model == "dall-e-3" {
|
} else if imageRequest.Model == "dall-e-3" {
|
||||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if imageRequest.N != 1 {
|
if imageRequest.N != 1 {
|
||||||
return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// N should between 1 and 10
|
// N should between 1 and 10
|
||||||
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||||
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
@ -78,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
modelMap := make(map[string]string)
|
modelMap := make(map[string]string)
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if modelMap[imageRequest.Model] != "" {
|
if modelMap[imageRequest.Model] != "" {
|
||||||
imageRequest.Model = modelMap[imageRequest.Model]
|
imageRequest.Model = modelMap[imageRequest.Model]
|
||||||
@ -90,10 +91,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
|
if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||||
apiVersion := common.GetAPIVersion(c)
|
apiVersion := relaycommon.GetAPIVersion(c)
|
||||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
|
||||||
}
|
}
|
||||||
@ -101,7 +102,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
||||||
jsonStr, err := json.Marshal(imageRequest)
|
jsonStr, err := json.Marshal(imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
} else {
|
} else {
|
||||||
@ -136,12 +137,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if consumeQuota && userQuota-quota < 0 {
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
token := c.Request.Header.Get("Authorization")
|
token := c.Request.Header.Get("Authorization")
|
||||||
@ -154,25 +155,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
resp, err := controller.httpClient.Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = req.Body.Close()
|
err = req.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = c.Request.Body.Close()
|
err = c.Request.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return relayErrorHandler(resp)
|
return relaycommon.RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
var textResponse ImageResponse
|
var textResponse dto.ImageResponse
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
@ -202,15 +203,15 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
@ -223,11 +224,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -9,8 +9,10 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -105,11 +107,11 @@ func RelayMidjourneyImage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||||||
var midjRequest Midjourney
|
var midjRequest Midjourney
|
||||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "bind_request_body_failed",
|
Description: "bind_request_body_failed",
|
||||||
Properties: nil,
|
Properties: nil,
|
||||||
@ -118,7 +120,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
|||||||
}
|
}
|
||||||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||||
if midjourneyTask == nil {
|
if midjourneyTask == nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "midjourney_task_not_found",
|
Description: "midjourney_task_not_found",
|
||||||
Properties: nil,
|
Properties: nil,
|
||||||
@ -136,7 +138,7 @@ func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
|||||||
midjourneyTask.FailReason = midjRequest.FailReason
|
midjourneyTask.FailReason = midjRequest.FailReason
|
||||||
err = midjourneyTask.Update()
|
err = midjourneyTask.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "update_midjourney_task_failed",
|
Description: "update_midjourney_task_failed",
|
||||||
}
|
}
|
||||||
@ -168,16 +170,16 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
var err error
|
var err error
|
||||||
var respBody []byte
|
var respBody []byte
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case RelayModeMidjourneyTaskFetch:
|
case relayconstant.RelayModeMidjourneyTaskFetch:
|
||||||
taskId := c.Param("id")
|
taskId := c.Param("id")
|
||||||
originTask := model.GetByMJId(userId, taskId)
|
originTask := model.GetByMJId(userId, taskId)
|
||||||
if originTask == nil {
|
if originTask == nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "task_no_found",
|
Description: "task_no_found",
|
||||||
}
|
}
|
||||||
@ -185,18 +187,18 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||||
respBody, err = json.Marshal(midjourneyTask)
|
respBody, err = json.Marshal(midjourneyTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "unmarshal_response_body_failed",
|
Description: "unmarshal_response_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
var condition = struct {
|
var condition = struct {
|
||||||
IDs []string `json:"ids"`
|
IDs []string `json:"ids"`
|
||||||
}{}
|
}{}
|
||||||
err = c.BindJSON(&condition)
|
err = c.BindJSON(&condition)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "do_request_failed",
|
Description: "do_request_failed",
|
||||||
}
|
}
|
||||||
@ -214,7 +216,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
}
|
}
|
||||||
respBody, err = json.Marshal(tasks)
|
respBody, err = json.Marshal(tasks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "unmarshal_response_body_failed",
|
Description: "unmarshal_response_body_failed",
|
||||||
}
|
}
|
||||||
@ -225,7 +227,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "copy_response_body_failed",
|
Description: "copy_response_body_failed",
|
||||||
}
|
}
|
||||||
@ -245,7 +247,7 @@ const (
|
|||||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
MJSubmitActionUpscale = "UPSCALE" // 放大
|
||||||
)
|
)
|
||||||
|
|
||||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||||
imageModel := "midjourney"
|
imageModel := "midjourney"
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
@ -254,60 +256,60 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
var midjRequest MidjourneyRequest
|
var midjRequest dto.MidjourneyRequest
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "bind_request_body_failed",
|
Description: "bind_request_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||||
if midjRequest.Prompt == "" {
|
if midjRequest.Prompt == "" {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "prompt_is_required",
|
Description: "prompt_is_required",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
midjRequest.Action = "IMAGINE"
|
midjRequest.Action = "IMAGINE"
|
||||||
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||||
midjRequest.Action = "DESCRIBE"
|
midjRequest.Action = "DESCRIBE"
|
||||||
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = "BLEND"
|
midjRequest.Action = "BLEND"
|
||||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
mjId := ""
|
mjId := ""
|
||||||
if relayMode == RelayModeMidjourneyChange {
|
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||||
if midjRequest.TaskId == "" {
|
if midjRequest.TaskId == "" {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "taskId_is_required",
|
Description: "taskId_is_required",
|
||||||
}
|
}
|
||||||
} else if midjRequest.Action == "" {
|
} else if midjRequest.Action == "" {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "action_is_required",
|
Description: "action_is_required",
|
||||||
}
|
}
|
||||||
} else if midjRequest.Index == 0 {
|
} else if midjRequest.Index == 0 {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "index_can_only_be_1_2_3_4",
|
Description: "index_can_only_be_1_2_3_4",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//action = midjRequest.Action
|
//action = midjRequest.Action
|
||||||
mjId = midjRequest.TaskId
|
mjId = midjRequest.TaskId
|
||||||
} else if relayMode == RelayModeMidjourneySimpleChange {
|
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||||
if midjRequest.Content == "" {
|
if midjRequest.Content == "" {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "content_is_required",
|
Description: "content_is_required",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params := convertSimpleChangeParams(midjRequest.Content)
|
params := convertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "content_parse_failed",
|
Description: "content_parse_failed",
|
||||||
}
|
}
|
||||||
@ -318,25 +320,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
originTask := model.GetByMJId(userId, mjId)
|
originTask := model.GetByMJId(userId, mjId)
|
||||||
if originTask == nil {
|
if originTask == nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "task_no_found",
|
Description: "task_no_found",
|
||||||
}
|
}
|
||||||
} else if originTask.Action == "UPSCALE" {
|
} else if originTask.Action == "UPSCALE" {
|
||||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "upscale_task_can_not_be_change",
|
Description: "upscale_task_can_not_be_change",
|
||||||
}
|
}
|
||||||
} else if originTask.Status != "SUCCESS" {
|
} else if originTask.Status != "SUCCESS" {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "task_status_is_not_success",
|
Description: "task_status_is_not_success",
|
||||||
}
|
}
|
||||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "channel_not_found",
|
Description: "channel_not_found",
|
||||||
}
|
}
|
||||||
@ -356,7 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "unmarshal_model_mapping_failed",
|
Description: "unmarshal_model_mapping_failed",
|
||||||
}
|
}
|
||||||
@ -383,7 +385,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
if isModelMapped {
|
if isModelMapped {
|
||||||
jsonStr, err := json.Marshal(midjRequest)
|
jsonStr, err := json.Marshal(midjRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "marshal_text_request_failed",
|
Description: "marshal_text_request_failed",
|
||||||
}
|
}
|
||||||
@ -407,7 +409,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
ratio := modelPrice * groupRatio
|
ratio := modelPrice * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: err.Error(),
|
Description: err.Error(),
|
||||||
}
|
}
|
||||||
@ -415,7 +417,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
quota := int(ratio * common.QuotaPerUnit)
|
quota := int(ratio * common.QuotaPerUnit)
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if consumeQuota && userQuota-quota < 0 {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "quota_not_enough",
|
Description: "quota_not_enough",
|
||||||
}
|
}
|
||||||
@ -423,7 +425,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "create_request_failed",
|
Description: "create_request_failed",
|
||||||
}
|
}
|
||||||
@ -442,9 +444,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
log.Printf("request header: %s", req.Header)
|
log.Printf("request header: %s", req.Header)
|
||||||
log.Printf("request body: %s", midjRequest.Prompt)
|
log.Printf("request body: %s", midjRequest.Prompt)
|
||||||
|
|
||||||
resp, err := controller.httpClient.Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "do_request_failed",
|
Description: "do_request_failed",
|
||||||
}
|
}
|
||||||
@ -452,19 +454,19 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
err = req.Body.Close()
|
err = req.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "close_request_body_failed",
|
Description: "close_request_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = c.Request.Body.Close()
|
err = c.Request.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "close_request_body_failed",
|
Description: "close_request_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var midjResponse MidjourneyResponse
|
var midjResponse dto.MidjourneyResponse
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
@ -493,14 +495,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "read_response_body_failed",
|
Description: "read_response_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "close_response_body_failed",
|
Description: "close_response_body_failed",
|
||||||
}
|
}
|
||||||
@ -510,13 +512,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
log.Printf("responseBody: %s", string(responseBody))
|
log.Printf("responseBody: %s", string(responseBody))
|
||||||
log.Printf("midjResponse: %v", midjResponse)
|
log.Printf("midjResponse: %v", midjResponse)
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "unmarshal_response_body_failed",
|
Description: "unmarshal_response_body_failed",
|
||||||
}
|
}
|
||||||
@ -579,7 +581,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
err = midjourneyTask.Insert()
|
err = midjourneyTask.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "insert_midjourney_task_failed",
|
Description: "insert_midjourney_task_failed",
|
||||||
}
|
}
|
||||||
@ -600,14 +602,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
|||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "copy_response_body_failed",
|
Description: "copy_response_body_failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
Code: 4,
|
||||||
Description: "close_response_body_failed",
|
Description: "close_response_body_failed",
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaychannel "one-api/relay/channel"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
@ -119,7 +118,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor := relaychannel.GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
41
relay/relay_adaptor.go
Normal file
41
relay/relay_adaptor.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/relay/channel"
|
||||||
|
"one-api/relay/channel/ali"
|
||||||
|
"one-api/relay/channel/baidu"
|
||||||
|
"one-api/relay/channel/claude"
|
||||||
|
"one-api/relay/channel/gemini"
|
||||||
|
"one-api/relay/channel/openai"
|
||||||
|
"one-api/relay/channel/palm"
|
||||||
|
"one-api/relay/channel/tencent"
|
||||||
|
"one-api/relay/channel/xunfei"
|
||||||
|
"one-api/relay/channel/zhipu"
|
||||||
|
"one-api/relay/constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetAdaptor(apiType int) channel.Adaptor {
|
||||||
|
switch apiType {
|
||||||
|
//case constant.APITypeAIProxyLibrary:
|
||||||
|
// return &aiproxy.Adaptor{}
|
||||||
|
case constant.APITypeAli:
|
||||||
|
return &ali.Adaptor{}
|
||||||
|
case constant.APITypeAnthropic:
|
||||||
|
return &claude.Adaptor{}
|
||||||
|
case constant.APITypeBaidu:
|
||||||
|
return &baidu.Adaptor{}
|
||||||
|
case constant.APITypeGemini:
|
||||||
|
return &gemini.Adaptor{}
|
||||||
|
case constant.APITypeOpenAI:
|
||||||
|
return &openai.Adaptor{}
|
||||||
|
case constant.APITypePaLM:
|
||||||
|
return &palm.Adaptor{}
|
||||||
|
case constant.APITypeTencent:
|
||||||
|
return &tencent.Adaptor{}
|
||||||
|
case constant.APITypeXunfei:
|
||||||
|
return &xunfei.Adaptor{}
|
||||||
|
case constant.APITypeZhipu:
|
||||||
|
return &zhipu.Adaptor{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -201,7 +201,7 @@ func CountTokenInput(input any, model string) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func countAudioToken(text string, model string) int {
|
func CountAudioToken(text string, model string) int {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
return utf8.RuneCountInString(text)
|
return utf8.RuneCountInString(text)
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user