feat: add Midjourney (#138)

* 🚧 stash

*  feat: add Midjourney

* 📝 doc: update readme
This commit is contained in:
Buer
2024-04-05 04:03:46 +08:00
committed by GitHub
parent 87bfecf3e9
commit c1fc32add7
42 changed files with 2479 additions and 84 deletions

View File

@@ -0,0 +1,121 @@
package midjourney
import (
"bytes"
"context"
"encoding/json"
"io"
"log"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"time"
)
// 定义供应商工厂
type MidjourneyProviderFactory struct{}
// 创建 MidjourneyProvider
func (f MidjourneyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &MidjourneyProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, nil),
},
}
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "",
}
}
type MidjourneyProvider struct {
base.BaseProvider
}
func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
var mapResult map[string]interface{}
if p.Context.Request.Method != "GET" {
err := json.NewDecoder(p.Context.Request.Body).Decode(&mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !common.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
}
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
fullRequestURL := p.GetFullRequestURL(requestURL, "")
var cancel context.CancelFunc
p.Requester.Context, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
headers := p.GetRequestHeaders()
defer cancel()
req, err := p.Requester.NewRequest(p.Context.Request.Method, fullRequestURL, p.Requester.WithBody(bytes.NewBuffer(reqBody)), p.Requester.WithHeader(headers))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
resp, errWith := p.Requester.SendRequestRaw(req)
if errWith != nil {
common.SysError("do request failed: " + errWith.Error())
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
err = req.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
err = p.Context.Request.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse MidjourneyResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
return &MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: midjResponse,
}, responseBody, nil
}
func (p *MidjourneyProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
headers["mj-api-secret"] = p.Channel.Key
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
headers["Accept"] = p.Context.Request.Header.Get("Accept")
return headers
}

View File

@@ -0,0 +1,69 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: relay/constant/relay_mode.go
package midjourney
const (
RelayModeUnknown = iota
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeMidjourneySwapFace
)
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: constant/midjourney.go
const (
MjErrorUnknown = 5
MjRequestError = 4
)
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
"mj_imagine": MjActionImagine,
"mj_describe": MjActionDescribe,
"mj_blend": MjActionBlend,
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
"mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
"mj_zoom": MjActionZoom,
"mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
}

View File

@@ -0,0 +1,18 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: service/error.go
package midjourney
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *MidjourneyResponseWithStatusCode {
return &MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: *MidjourneyErrorWrapper(code, desc),
}
}
func MidjourneyErrorWrapper(code int, desc string) *MidjourneyResponse {
return &MidjourneyResponse{
Code: code,
Description: desc,
}
}

View File

@@ -0,0 +1,92 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: dto/midjourney.go
package midjourney
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
Type string `json:"type,omitempty"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
}
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
type Properties struct {
FinalPrompt string `json:"finalPrompt"`
FinalZhPrompt string `json:"finalZhPrompt"`
}