From 2a9c3ac6afb43f92c7225eacca913653dada9907 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 12 Jan 2024 13:45:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dmj=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E8=BF=94=E8=BF=98=E8=B4=B9=E7=94=A8=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/midjourney.go | 21 ++++++++++++++++----- controller/relay-mj.go | 1 + model/midjourney.go | 11 +++++++++-- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index 6791442..3e0faff 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -154,7 +154,7 @@ func UpdateMidjourneyTaskBulk() { log.Printf("UpdateMidjourneyTask panic: %v", err) } }() - imageModel := "midjourney" + //imageModel := "midjourney" ctx := context.TODO() for { time.Sleep(time.Duration(15) * time.Second) @@ -167,13 +167,27 @@ func UpdateMidjourneyTaskBulk() { common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) + nullTaskIds := make([]int, 0) for _, task := range tasks { if task.MjId == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.Id) continue } taskM[task.MjId] = task taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId) } + if len(nullTaskIds) > 0 { + err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + } else { + common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + } + } if len(taskChannelM) == 0 { continue } @@ -256,10 +270,7 @@ func UpdateMidjourneyTaskBulk() { if err != nil { common.LogError(ctx, "error update user quota cache: "+err.Error()) } else { - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio("default") - ratio := modelRatio * groupRatio - quota := int(ratio * 1 * 1000) + quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota) if err != nil { diff --git a/controller/relay-mj.go b/controller/relay-mj.go index 30aa146..df6bfcc 100644 --- a/controller/relay-mj.go +++ b/controller/relay-mj.go @@ -544,6 +544,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Progress: "0%", FailReason: "", ChannelId: c.GetInt("channel_id"), + Quota: quota, } if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { diff --git a/model/midjourney.go b/model/midjourney.go index 85b42c3..0ef2e55 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -18,6 +18,7 @@ type Midjourney struct { Progress string `json:"progress"` FailReason string `json:"fail_reason"` ChannelId int `json:"channel_id"` + Quota int `json:"quota"` } // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 @@ -152,8 +153,14 @@ func (midjourney *Midjourney) Update() error { return err } -func MjBulkUpdate(taskIDs []string, params map[string]any) error { +func MjBulkUpdate(mjIds []string, params map[string]any) error { return DB.Model(&Midjourney{}). - Where("mj_id in (?)", taskIDs). + Where("mj_id in (?)", mjIds). + Updates(params).Error +} + +func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error { + return DB.Model(&Midjourney{}). + Where("id in (?)", taskIDs). Updates(params).Error }