feat: 兼容自定义变焦,完善modal操作

This commit is contained in:
CaIon
2024-03-14 16:42:37 +08:00
parent 614220a0fb
commit d704902b70
12 changed files with 147 additions and 110 deletions

View File

@@ -1,11 +1,18 @@
package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"strconv"
"strings"
"time"
)
func CoverActionToModelName(mjAction string) string {
@@ -35,7 +42,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
case relayconstant.RelayModeMidjourneyChange:
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionInPaint
action = constant.MjActionModal
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -96,11 +103,14 @@ func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.Midj
} else if strings.Contains(action, "reroll") {
midjRequest.Action = constant.MjActionReRoll
midjRequest.Index = 1
} else if action == "Outpaint" || action == "CustomZoom" {
} else if action == "Outpaint" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
} else if action == "CustomZoom" {
midjRequest.Action = constant.MjActionCustomZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = constant.MjActionInPaintPre
midjRequest.Action = constant.MjActionInPaint
midjRequest.Index = 1
} else {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
@@ -136,3 +146,60 @@ func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
changeParams.Index = index
return changeParams
}
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string, midjRequest *dto.MidjourneyRequest) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
var requestBody io.Reader
requestBody = c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
//if statusCode != 200 {
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
//}
err = req.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
err = c.Request.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
//log.Printf("midjResponse: %v", midjResponse)
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
return &dto.MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: midjResponse,
}, responseBody, nil
}