mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-18 00:16:37 +08:00
251 lines
8.5 KiB
Go
251 lines
8.5 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
relayconstant "one-api/relay/constant"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func CoverActionToModelName(mjAction string) string {
|
|
modelName := "mj_" + strings.ToLower(mjAction)
|
|
if mjAction == constant.MjActionSwapFace {
|
|
modelName = "swap_face"
|
|
}
|
|
return modelName
|
|
}
|
|
|
|
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
|
|
action := ""
|
|
if relayMode == relayconstant.RelayModeMidjourneyAction {
|
|
// plus request
|
|
err := CoverPlusActionToNormalAction(midjRequest)
|
|
if err != nil {
|
|
return "", err, false
|
|
}
|
|
action = midjRequest.Action
|
|
} else {
|
|
switch relayMode {
|
|
case relayconstant.RelayModeMidjourneyImagine:
|
|
action = constant.MjActionImagine
|
|
case relayconstant.RelayModeMidjourneyDescribe:
|
|
action = constant.MjActionDescribe
|
|
case relayconstant.RelayModeMidjourneyBlend:
|
|
action = constant.MjActionBlend
|
|
case relayconstant.RelayModeMidjourneyShorten:
|
|
action = constant.MjActionShorten
|
|
case relayconstant.RelayModeMidjourneyChange:
|
|
action = midjRequest.Action
|
|
case relayconstant.RelayModeMidjourneyModal:
|
|
action = constant.MjActionModal
|
|
case relayconstant.RelayModeSwapFace:
|
|
action = constant.MjActionSwapFace
|
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
|
if params == nil {
|
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
|
|
}
|
|
action = params.Action
|
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
|
|
return "", nil, true
|
|
default:
|
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
|
|
}
|
|
}
|
|
modelName := CoverActionToModelName(action)
|
|
return modelName, nil, true
|
|
}
|
|
|
|
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
|
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
|
customId := midjRequest.CustomId
|
|
if customId == "" {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
|
}
|
|
splits := strings.Split(customId, "::")
|
|
var action string
|
|
if splits[1] == "JOB" {
|
|
action = splits[2]
|
|
} else {
|
|
action = splits[1]
|
|
}
|
|
|
|
if action == "" {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
}
|
|
if strings.Contains(action, "upsample") {
|
|
index, err := strconv.Atoi(splits[3])
|
|
if err != nil {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
}
|
|
midjRequest.Index = index
|
|
midjRequest.Action = constant.MjActionUpscale
|
|
} else if strings.Contains(action, "variation") {
|
|
midjRequest.Index = 1
|
|
if action == "variation" {
|
|
index, err := strconv.Atoi(splits[3])
|
|
if err != nil {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
}
|
|
midjRequest.Index = index
|
|
midjRequest.Action = constant.MjActionVariation
|
|
} else if action == "low_variation" {
|
|
midjRequest.Action = constant.MjActionLowVariation
|
|
} else if action == "high_variation" {
|
|
midjRequest.Action = constant.MjActionHighVariation
|
|
}
|
|
} else if strings.Contains(action, "pan") {
|
|
midjRequest.Action = constant.MjActionPan
|
|
midjRequest.Index = 1
|
|
} else if strings.Contains(action, "reroll") {
|
|
midjRequest.Action = constant.MjActionReRoll
|
|
midjRequest.Index = 1
|
|
} else if action == "Outpaint" {
|
|
midjRequest.Action = constant.MjActionZoom
|
|
midjRequest.Index = 1
|
|
} else if action == "CustomZoom" {
|
|
midjRequest.Action = constant.MjActionCustomZoom
|
|
midjRequest.Index = 1
|
|
} else if action == "Inpaint" {
|
|
midjRequest.Action = constant.MjActionInPaint
|
|
midjRequest.Index = 1
|
|
} else {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
|
split := strings.Split(content, " ")
|
|
if len(split) != 2 {
|
|
return nil
|
|
}
|
|
|
|
action := strings.ToLower(split[1])
|
|
changeParams := &dto.MidjourneyRequest{}
|
|
changeParams.TaskId = split[0]
|
|
|
|
if action[0] == 'u' {
|
|
changeParams.Action = "UPSCALE"
|
|
} else if action[0] == 'v' {
|
|
changeParams.Action = "VARIATION"
|
|
} else if action == "r" {
|
|
changeParams.Action = "REROLL"
|
|
return changeParams
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
index, err := strconv.Atoi(action[1:2])
|
|
if err != nil || index < 1 || index > 4 {
|
|
return nil
|
|
}
|
|
changeParams.Index = index
|
|
return changeParams
|
|
}
|
|
|
|
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
|
var nullBytes []byte
|
|
//var requestBody io.Reader
|
|
//requestBody = c.Request.Body
|
|
// read request body to json, delete accountFilter and notifyHook
|
|
var mapResult map[string]interface{}
|
|
// if get request, no need to read request body
|
|
if c.Request.Method != "GET" {
|
|
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
if !constant.MjAccountFilterEnabled {
|
|
delete(mapResult, "accountFilter")
|
|
}
|
|
if !constant.MjNotifyEnabled {
|
|
delete(mapResult, "notifyHook")
|
|
}
|
|
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
// make new request with mapResult
|
|
}
|
|
if constant.MjModeClearEnabled {
|
|
if prompt, ok := mapResult["prompt"].(string); ok {
|
|
prompt = strings.Replace(prompt, "--fast", "", -1)
|
|
prompt = strings.Replace(prompt, "--relax", "", -1)
|
|
prompt = strings.Replace(prompt, "--turbo", "", -1)
|
|
|
|
mapResult["prompt"] = prompt
|
|
}
|
|
}
|
|
reqBody, err := json.Marshal(mapResult)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
// 使用带有超时的 context 创建新的请求
|
|
req = req.WithContext(ctx)
|
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
auth := c.Request.Header.Get("Authorization")
|
|
if auth != "" {
|
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
|
req.Header.Set("mj-api-secret", auth)
|
|
}
|
|
defer cancel()
|
|
resp, err := GetHttpClient().Do(req)
|
|
if err != nil {
|
|
common.SysError("do request failed: " + err.Error())
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
statusCode := resp.StatusCode
|
|
//if statusCode != 200 {
|
|
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
|
|
//}
|
|
err = req.Body.Close()
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
|
}
|
|
err = c.Request.Body.Close()
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
|
}
|
|
var midjResponse dto.MidjourneyResponse
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
|
}
|
|
respStr := string(responseBody)
|
|
log.Printf("responseBody: %s", respStr)
|
|
if respStr == "" {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
|
} else {
|
|
err = json.Unmarshal(responseBody, &midjResponse)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
|
}
|
|
}
|
|
//log.Printf("midjResponse: %v", midjResponse)
|
|
//for k, v := range resp.Header {
|
|
// c.Writer.Header().Set(k, v[0])
|
|
//}
|
|
return &dto.MidjourneyResponseWithStatusCode{
|
|
StatusCode: statusCode,
|
|
Response: midjResponse,
|
|
}, responseBody, nil
|
|
}
|