From 5c747dfee24061fcd618a9981d54db351dd3d75f Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Mon, 1 Jan 2024 22:46:05 +0800 Subject: [PATCH] =?UTF-8?q?optimize:=20MJ=20=E9=83=A8=E5=88=86=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E3=80=81=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MJ 增加simple-change、list接口, 变换和重试操作区别出来,价格与绘图一样 优化图片返回 --- common/model-ratio.go | 8 +- controller/relay-mj.go | 229 +++++++++++++++++++++++++++++++---------- controller/relay.go | 12 ++- model/midjourney.go | 22 +++- router/relay-router.go | 2 + 5 files changed, 217 insertions(+), 56 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index b18ba0d..774dfe2 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -14,7 +14,7 @@ import ( // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ - "midjourney": 50, + //"midjourney": 50, "gpt-4-gizmo-*": 15, "gpt-4": 15, "gpt-4-0314": 15, @@ -80,6 +80,12 @@ var ModelRatio = map[string]float64{ var ModelPrice = map[string]float64{ "gpt-4-gizmo-*": 0.1, + "mj_imagine": 0.1, + "mj_variation": 0.1, + "mj_reroll": 0.1, + "mj_blend": 0.1, + "mj_describe": 0.05, + "mj_upscale": 0.05, } func ModelPrice2JSONString() string { diff --git a/controller/relay-mj.go b/controller/relay-mj.go index 8d88473..ad1e7f7 100644 --- a/controller/relay-mj.go +++ b/controller/relay-mj.go @@ -57,7 +57,7 @@ type MidjourneyWithoutStatus struct { func RelayMidjourneyImage(c *gin.Context) { taskId := c.Param("id") - midjourneyTask := model.GetByMJId(taskId) + midjourneyTask := model.GetByOnlyMJId(taskId) if midjourneyTask == nil { c.JSON(400, gin.H{ "error": "midjourney_task_not_found", @@ -71,14 +71,27 @@ func RelayMidjourneyImage(c *gin.Context) { }) } defer resp.Body.Close() - data, err := io.ReadAll(resp.Body) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + if resp.StatusCode != http.StatusOK { + responseBody, _ := io.ReadAll(resp.Body) + c.JSON(resp.StatusCode, gin.H{ + "error": string(responseBody), + }) return } - c.Header("Content-Type", "image/jpeg") - //c.HeaderBar("Content-Length", string(rune(len(data)))) - c.Data(http.StatusOK, "image/jpeg", data) + // 从Content-Type头获取MIME类型 + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + // 如果无法确定内容类型,则默认为jpeg + contentType = "image/jpeg" + } + // 设置响应的内容类型 + c.Writer.Header().Set("Content-Type", contentType) + // 将图片流式传输到响应体 + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + log.Println("Failed to stream image:", err) + } + return } func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { @@ -92,7 +105,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { Result: "", } } - midjourneyTask := model.GetByMJId(midjRequest.MjId) + midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId) if midjourneyTask == nil { return &MidjourneyResponse{ Code: 4, @@ -121,16 +134,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { return nil } -func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { - taskId := c.Param("id") - originTask := model.GetByMJId(taskId) - if originTask == nil { - return &MidjourneyResponse{ - Code: 4, - Description: "task_no_found", - } - } - var midjourneyTask Midjourney +func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) { midjourneyTask.MjId = originTask.MjId midjourneyTask.Progress = originTask.Progress midjourneyTask.PromptEn = originTask.PromptEn @@ -150,14 +154,65 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { midjourneyTask.Action = originTask.Action midjourneyTask.Description = originTask.Description midjourneyTask.Prompt = originTask.Prompt - jsonMap, err := json.Marshal(midjourneyTask) - if err != nil { - return &MidjourneyResponse{ - Code: 4, - Description: "unmarshal_response_body_failed", + return +} + +func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { + userId := c.GetInt("id") + var err error + var respBody []byte + switch relayMode { + case RelayModeMidjourneyTaskFetch: + taskId := c.Param("id") + originTask := model.GetByMJId(userId, taskId) + if originTask == nil { + return &MidjourneyResponse{ + Code: 4, + Description: "task_no_found", + } + } + midjourneyTask := getMidjourneyTaskModel(c, originTask) + respBody, err = json.Marshal(midjourneyTask) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } + } + case RelayModeMidjourneyTaskFetchByCondition: + var condition = struct { + IDs []string `json:"ids"` + }{} + err = c.BindJSON(&condition) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "do_request_failed", + } + } + var tasks []Midjourney + if len(condition.IDs) != 0 { + originTasks := model.GetByMJIds(userId, condition.IDs) + for _, originTask := range originTasks { + midjourneyTask := getMidjourneyTaskModel(c, originTask) + tasks = append(tasks, midjourneyTask) + } + } + if tasks == nil { + tasks = make([]Midjourney, 0) + } + respBody, err = json.Marshal(tasks) + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: "unmarshal_response_body_failed", + } } } - _, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap)) + + c.Writer.Header().Set("Content-Type", "application/json") + + _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) if err != nil { return &MidjourneyResponse{ Code: 4, @@ -167,6 +222,18 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { return nil } +const ( + // type 1 根据 mode 价格不同 + MJSubmitActionImagine = "IMAGINE" + MJSubmitActionVariation = "VARIATION" //变换 + MJSubmitActionBlend = "BLEND" //混图 + + MJSubmitActionReroll = "REROLL" //重新生成 + // type 2 固定价格 + MJSubmitActionDescribe = "DESCRIBE" + MJSubmitActionUpscale = "UPSCALE" // 放大 +) + func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { imageModel := "midjourney" @@ -186,6 +253,9 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } } } + + action := midjRequest.Action + if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return &MidjourneyResponse{ @@ -199,7 +269,44 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = "BLEND" } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 - originTask := model.GetByMJId(midjRequest.TaskId) + mjId := "" + if relayMode == RelayModeMidjourneyChange { + if midjRequest.TaskId == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "taskId_is_required", + } + } else if midjRequest.Action == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "action_is_required", + } + } else if midjRequest.Index == 0 { + return &MidjourneyResponse{ + Code: 4, + Description: "index_can_only_be_1_2_3_4", + } + } + action = midjRequest.Action + mjId = midjRequest.TaskId + } else if relayMode == RelayModeMidjourneySimpleChange { + if midjRequest.Content == "" { + return &MidjourneyResponse{ + Code: 4, + Description: "content_is_required", + } + } + params := convertSimpleChangeParams(midjRequest.Content) + if params == nil { + return &MidjourneyResponse{ + Code: 4, + Description: "content_parse_failed", + } + } + mjId = params.ID + action = params.Action + } + originTask := model.GetByMJId(userId, mjId) if originTask == nil { return &MidjourneyResponse{ Code: 4, @@ -229,23 +336,6 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL()) } midjRequest.Prompt = originTask.Prompt - } else if relayMode == RelayModeMidjourneyChange { - if midjRequest.TaskId == "" { - return &MidjourneyResponse{ - Code: 4, - Description: "taskId_is_required", - } - } else if midjRequest.Action == "" { - return &MidjourneyResponse{ - Code: 4, - Description: "action_is_required", - } - } else if midjRequest.Index == 0 { - return &MidjourneyResponse{ - Code: 4, - Description: "index_can_only_be_1_2_3_4", - } - } } // map model name @@ -293,17 +383,17 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageModel) + modelPrice := common.GetModelPrice("mj_" + strings.ToLower(action)) groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio + ratio := modelPrice * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - - sizeRatio := 1.0 - if midjRequest.Action == "UPSCALE" { - sizeRatio = 0.2 + if err != nil { + return &MidjourneyResponse{ + Code: 4, + Description: err.Error(), + } } - - quota := int(ratio * sizeRatio * 1000) + quota := int(ratio * common.QuotaPerUnit) if consumeQuota && userQuota-quota < 0 { return &MidjourneyResponse{ @@ -369,7 +459,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, action) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") @@ -423,7 +513,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { midjourneyTask := &model.Midjourney{ UserId: userId, Code: midjResponse.Code, - Action: midjRequest.Action, + Action: action, MjId: midjResponse.Result, Prompt: midjRequest.Prompt, PromptEn: "", @@ -504,3 +594,38 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { } return nil } + +type taskChangeParams struct { + ID string + Action string + Index int +} + +func convertSimpleChangeParams(content string) *taskChangeParams { + split := strings.Split(content, " ") + if len(split) != 2 { + return nil + } + + action := strings.ToLower(split[1]) + changeParams := &taskChangeParams{} + changeParams.ID = split[0] + + if action[0] == 'u' { + changeParams.Action = "UPSCALE" + } else if action[0] == 'v' { + changeParams.Action = "VARIATION" + } else if action == "r" { + changeParams.Action = "REROLL" + return changeParams + } else { + return nil + } + + index, err := strconv.Atoi(action[1:2]) + if err != nil || index < 1 || index > 4 { + return nil + } + changeParams.Index = index + return changeParams +} diff --git a/controller/relay.go b/controller/relay.go index 4cce28d..3850b2f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -95,8 +95,10 @@ const ( RelayModeMidjourneyDescribe RelayModeMidjourneyBlend RelayModeMidjourneyChange + RelayModeMidjourneySimpleChange RelayModeMidjourneyNotify RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskFetchByCondition RelayModeAudio ) @@ -263,6 +265,7 @@ type MidjourneyRequest struct { State string `json:"state"` TaskId string `json:"taskId"` Base64Array []string `json:"base64Array"` + Content string `json:"content"` } type MidjourneyResponse struct { @@ -342,14 +345,19 @@ func RelayMidjourney(c *gin.Context) { relayMode = RelayModeMidjourneyNotify } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") { relayMode = RelayModeMidjourneyChange - } else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") { + } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") { + relayMode = RelayModeMidjourneyChange + } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") { relayMode = RelayModeMidjourneyTaskFetch + } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") { + relayMode = RelayModeMidjourneyTaskFetchByCondition } + var err *MidjourneyResponse switch relayMode { case RelayModeMidjourneyNotify: err = relayMidjourneyNotify(c) - case RelayModeMidjourneyTaskFetch: + case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition: err = relayMidjourneyTask(c, relayMode) default: err = relayMidjourneySubmit(c, relayMode) diff --git a/model/midjourney.go b/model/midjourney.go index 84d228e..85b42c3 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -96,7 +96,7 @@ func GetAllUnFinishTasks() []*Midjourney { return tasks } -func GetByMJId(mjId string) *Midjourney { +func GetByOnlyMJId(mjId string) *Midjourney { var mj *Midjourney var err error err = DB.Where("mj_id = ?", mjId).First(&mj).Error @@ -106,6 +106,26 @@ func GetByMJId(mjId string) *Midjourney { return mj } +func GetByMJId(userId int, mjId string) *Midjourney { + var mj *Midjourney + var err error + err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error + if err != nil { + return nil + } + return mj +} + +func GetByMJIds(userId int, mjIds []string) []*Midjourney { + var mj []*Midjourney + var err error + err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error + if err != nil { + return nil + } + return mj +} + func GetMjByuId(id int) *Midjourney { var mj *Midjourney var err error diff --git a/router/relay-router.go b/router/relay-router.go index 787e6cf..fd80b30 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -49,10 +49,12 @@ func SetRelayRouter(router *gin.Engine) { { relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) relayMjRouter.POST("/submit/change", controller.RelayMidjourney) + relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) relayMjRouter.POST("/submit/describe", controller.RelayMidjourney) relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) relayMjRouter.POST("/notify", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) + relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) } //relayMjRouter.Use() }