new-api/service/midjourney.go
2024-03-14 22:17:03 +08:00

234 lines
8.0 KiB
Go

package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"strconv"
"strings"
"time"
)
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == constant.MjActionSwapFace {
modelName = "swap_face"
}
return modelName
}
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
action := ""
if relayMode == relayconstant.RelayModeMidjourneyAction {
// plus request
err := CoverPlusActionToNormalAction(midjRequest)
if err != nil {
return "", err, false
}
action = midjRequest.Action
} else {
switch relayMode {
case relayconstant.RelayModeMidjourneyImagine:
action = constant.MjActionImagine
case relayconstant.RelayModeMidjourneyDescribe:
action = constant.MjActionDescribe
case relayconstant.RelayModeMidjourneyBlend:
action = constant.MjActionBlend
case relayconstant.RelayModeMidjourneyShorten:
action = constant.MjActionShorten
case relayconstant.RelayModeMidjourneyChange:
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
}
action = params.Action
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
return "", nil, true
default:
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
}
}
modelName := CoverActionToModelName(action)
return modelName, nil, true
}
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
customId := midjRequest.CustomId
if customId == "" {
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
}
splits := strings.Split(customId, "::")
var action string
if splits[1] == "JOB" {
action = splits[2]
} else {
action = splits[1]
}
if action == "" {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
}
if strings.Contains(action, "upsample") {
index, err := strconv.Atoi(splits[3])
if err != nil {
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = constant.MjActionUpscale
} else if strings.Contains(action, "variation") {
midjRequest.Index = 1
if action == "variation" {
index, err := strconv.Atoi(splits[3])
if err != nil {
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = constant.MjActionVariation
} else if action == "low_variation" {
midjRequest.Action = constant.MjActionLowVariation
} else if action == "high_variation" {
midjRequest.Action = constant.MjActionHighVariation
}
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionPan
midjRequest.Index = 1
} else if strings.Contains(action, "reroll") {
midjRequest.Action = constant.MjActionReRoll
midjRequest.Index = 1
} else if action == "Outpaint" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
} else if action == "CustomZoom" {
midjRequest.Action = constant.MjActionCustomZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = constant.MjActionInPaint
midjRequest.Index = 1
} else {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
}
return nil
}
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &dto.MidjourneyRequest{}
changeParams.TaskId = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
//var requestBody io.Reader
//requestBody = c.Request.Body
// read request body to json, delete accountFilter and notifyHook
var mapResult map[string]interface{}
// if get request, no need to read request body
if c.Request.Method != "GET" {
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !constant.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
}
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
//if statusCode != 200 {
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
//}
err = req.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
err = c.Request.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
//log.Printf("midjResponse: %v", midjResponse)
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
return &dto.MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: midjResponse,
}, responseBody, nil
}