mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
✨ feat: add Midjourney (#138)
* 🚧 stash * ✨ feat: add Midjourney * 📝 doc: update readme
This commit is contained in:
121
providers/midjourney/base.go
Normal file
121
providers/midjourney/base.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type MidjourneyProviderFactory struct{}
|
||||
|
||||
// 创建 MidjourneyProvider
|
||||
func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &MidjourneyProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "",
|
||||
}
|
||||
}
|
||||
|
||||
type MidjourneyProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) {
|
||||
var nullBytes []byte
|
||||
var mapResult map[string]interface{}
|
||||
if p.Context.Request.Method != "GET" {
|
||||
err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
delete(mapResult, "accountFilter")
|
||||
if !common.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(requestURL, "")
|
||||
|
||||
var cancel context.CancelFunc
|
||||
p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
headers := p.GetRequestHeaders()
|
||||
defer cancel()
|
||||
|
||||
req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
|
||||
resp, errWith := p.Requester.SendRequestRaw(req)
|
||||
if errWith != nil {
|
||||
common.SysError("do request failed: " + errWith.Error())
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
statusCode := resp.StatusCode
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = p.Context.Request.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
var midjResponse MidjourneyResponse
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
respStr := string(responseBody)
|
||||
log.Printf("responseBody: %s", respStr)
|
||||
if respStr == "" {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||
} else {
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
}
|
||||
|
||||
return &MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: midjResponse,
|
||||
}, responseBody, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
headers["mj-api-secret"] = p.Channel.Key
|
||||
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||
|
||||
return headers
|
||||
}
|
||||
69
providers/midjourney/constant.go
Normal file
69
providers/midjourney/constant.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: relay/constant/relay_mode.go
|
||||
package midjourney
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskImageSeed
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
RelayModeMidjourneyAction
|
||||
RelayModeMidjourneyModal
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeMidjourneySwapFace
|
||||
)
|
||||
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: constant/midjourney.go
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MjActionImagine = "IMAGINE"
|
||||
MjActionDescribe = "DESCRIBE"
|
||||
MjActionBlend = "BLEND"
|
||||
MjActionUpscale = "UPSCALE"
|
||||
MjActionVariation = "VARIATION"
|
||||
MjActionReRoll = "REROLL"
|
||||
MjActionInPaint = "INPAINT"
|
||||
MjActionModal = "MODAL"
|
||||
MjActionZoom = "ZOOM"
|
||||
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||
MjActionShorten = "SHORTEN"
|
||||
MjActionHighVariation = "HIGH_VARIATION"
|
||||
MjActionLowVariation = "LOW_VARIATION"
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
"mj_imagine": MjActionImagine,
|
||||
"mj_describe": MjActionDescribe,
|
||||
"mj_blend": MjActionBlend,
|
||||
"mj_upscale": MjActionUpscale,
|
||||
"mj_variation": MjActionVariation,
|
||||
"mj_reroll": MjActionReRoll,
|
||||
"mj_modal": MjActionModal,
|
||||
"mj_inpaint": MjActionInPaint,
|
||||
"mj_zoom": MjActionZoom,
|
||||
"mj_custom_zoom": MjActionCustomZoom,
|
||||
"mj_shorten": MjActionShorten,
|
||||
"mj_high_variation": MjActionHighVariation,
|
||||
"mj_low_variation": MjActionLowVariation,
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
}
|
||||
18
providers/midjourney/error.go
Normal file
18
providers/midjourney/error.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: service/error.go
|
||||
package midjourney
|
||||
|
||||
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode {
|
||||
return &MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: *MidjourneyErrorWrapper(code, desc),
|
||||
}
|
||||
}
|
||||
|
||||
func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse {
|
||||
return &MidjourneyResponse{
|
||||
Code: code,
|
||||
Description: desc,
|
||||
}
|
||||
}
|
||||
92
providers/midjourney/type.go
Normal file
92
providers/midjourney/type.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
// Path: dto/midjourney.go
|
||||
package midjourney
|
||||
|
||||
type SwapFaceRequest struct {
|
||||
SourceBase64 string `json:"sourceBase64"`
|
||||
TargetBase64 string `json:"targetBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type MidjourneyResponseWithStatusCode struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Response MidjourneyResponse
|
||||
}
|
||||
|
||||
type MidjourneyDto struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
Buttons any `json:"buttons"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
type ActionButton struct {
|
||||
CustomId any `json:"customId"`
|
||||
Emoji any `json:"emoji"`
|
||||
Label any `json:"label"`
|
||||
Type any `json:"type"`
|
||||
Style any `json:"style"`
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
FinalPrompt string `json:"finalPrompt"`
|
||||
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"one-api/providers/deepseek"
|
||||
"one-api/providers/gemini"
|
||||
"one-api/providers/groq"
|
||||
"one-api/providers/midjourney"
|
||||
"one-api/providers/minimax"
|
||||
"one-api/providers/mistral"
|
||||
"one-api/providers/openai"
|
||||
@@ -52,6 +53,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
|
||||
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
|
||||
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
|
||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user