fix mj bug

This commit is contained in:
CaIon 2023-11-06 02:08:12 +08:00
parent 3d87f868a3
commit de596ce90c
3 changed files with 141 additions and 16 deletions

View File

@ -2,14 +2,17 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
)
@ -25,7 +28,9 @@ func UpdateMidjourneyTask() {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
log.Printf("检测到未完成的任务数有: %v", len(tasks))
for _, task := range tasks {
log.Printf("未完成的任务信息: %v", task)
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
@ -39,6 +44,7 @@ func UpdateMidjourneyTask() {
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
log.Printf("requestUrl: %s", requestUrl)
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
@ -46,7 +52,16 @@ func UpdateMidjourneyTask() {
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
@ -54,11 +69,37 @@ func UpdateMidjourneyTask() {
continue
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
err = json.NewDecoder(resp.Body).Decode(&responseItem)
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
err2 := json.Unmarshal(responseBody, &responseStatus)
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
} else {
log.Printf("UpdateMidjourneyTask error3: %v", err)
continue
}
} else {
log.Printf("UpdateMidjourneyTask error4: %v", err)
continue
}
}
task.Code = 1
task.Progress = responseItem.Progress
@ -94,7 +135,7 @@ func UpdateMidjourneyTask() {
err = task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
}

View File

@ -12,6 +12,7 @@ import (
"one-api/model"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
@ -32,6 +33,28 @@ type Midjourney struct {
FailReason string `json:"failReason"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByMJId(taskId)
@ -115,7 +138,13 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
@ -157,7 +186,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
}
}
if relayMode == RelayModeMidjourneyImagine {
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Code: 4,
@ -165,7 +194,11 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
}
midjRequest.Action = "IMAGINE"
} else if midjRequest.TaskId != "" {
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = "DESCRIBE"
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
originTask := model.GetByMJId(midjRequest.TaskId)
if originTask == nil {
return &MidjourneyResponse{
@ -183,7 +216,17 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Code: 4,
Description: "task_status_is_not_success",
}
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "channel_not_found",
}
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
log.Printf("检测到此操作为放大、变换获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
@ -234,6 +277,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
@ -275,14 +319,15 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Description: "create_request_failed",
}
}
//req.HeaderBar.Set("Authorization", c.Request.HeaderBar.Get("Authorization"))
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//mjToken := ""
//if c.Request.HeaderBar.Get("Authorization") != "" {
// mjToken = strings.Split(c.Request.HeaderBar.Get("Authorization"), " ")[1]
//if c.Request.Header.Get("Authorization") != "" {
// mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
//}
req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
@ -367,10 +412,14 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Description: "unmarshal_response_body_failed",
}
}
if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 {
consumeQuota = false
}
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
// 22-排队中 {"code":22,"description":"排队中前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
@ -380,7 +429,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: 0,
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
@ -389,9 +438,35 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
FailReason: "",
ChannelId: c.GetInt("channel_id"),
}
if midjResponse.Code == 4 || midjResponse.Code == 24 {
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description
consumeQuota = false
}
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
// 将 properties 转换为一个 map
properties, ok := midjResponse.Properties.(map[string]interface{})
if ok {
imageUrl, ok1 := properties["imageUrl"].(string)
status, ok2 := properties["status"].(string)
if ok1 && ok2 {
midjourneyTask.ImageUrl = imageUrl
midjourneyTask.Status = status
if status == "SUCCESS" {
midjourneyTask.Progress = "100%"
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
midjResponse.Code = 1
}
}
}
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
err = midjourneyTask.Insert()
if err != nil {
return &MidjourneyResponse{
@ -399,6 +474,13 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Description: "insert_midjourney_task_failed",
}
}
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {

View File

@ -26,6 +26,8 @@ const (
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch