diff --git a/controller/midjourney.go b/controller/midjourney.go index 677f916..6791442 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -16,8 +16,10 @@ import ( "time" ) -func UpdateMidjourneyTask() { +/*func UpdateMidjourneyTask() { //revocer + //imageModel := "midjourney" + ctx := context.TODO() imageModel := "midjourney" defer func() { if err := recover(); err != nil { @@ -28,27 +30,28 @@ func UpdateMidjourneyTask() { time.Sleep(time.Duration(15) * time.Second) tasks := model.GetAllUnFinishTasks() if len(tasks) != 0 { - log.Printf("检测到未完成的任务数有: %v", len(tasks)) + common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) for _, task := range tasks { - log.Printf("未完成的任务信息: %v", task) + common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task)) midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) if err != nil { - log.Printf("UpdateMidjourneyTask: %v", err) + common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err)) task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId) task.Status = "FAILURE" task.Progress = "100%" err := task.Update() if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) + common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + continue } continue } requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) - log.Printf("requestUrl: %s", requestUrl) + common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl)) req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) if err != nil { - log.Printf("UpdateMidjourneyTask error: %v", err) + common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } @@ -111,7 +114,7 @@ func UpdateMidjourneyTask() { task.Status = responseItem.Status task.FailReason = responseItem.FailReason if task.Progress != "100%" && responseItem.FailReason != "" { - log.Println(task.MjId + " 构建失败," + task.FailReason) + common.LogWarn(task.MjId + " 构建失败," + task.FailReason) task.Progress = "100%" err = model.CacheUpdateUserQuota(task.UserId) if err != nil { @@ -126,8 +129,8 @@ func UpdateMidjourneyTask() { if err != nil { log.Println("fail to increase user quota") } - logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, 1, logContent) + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } } @@ -142,6 +145,180 @@ func UpdateMidjourneyTask() { } } } +*/ + +func UpdateMidjourneyTaskBulk() { + //revocer + defer func() { + if err := recover(); err != nil { + log.Printf("UpdateMidjourneyTask panic: %v", err) + } + }() + imageModel := "midjourney" + ctx := context.TODO() + for { + time.Sleep(time.Duration(15) * time.Second) + + tasks := model.GetAllUnFinishTasks() + if len(tasks) == 0 { + continue + } + + common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Midjourney) + for _, task := range tasks { + if task.MjId == "" { + continue + } + taskM[task.MjId] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId) + } + if len(taskChannelM) == 0 { + continue + } + + for channelId, taskIds := range taskChannelM { + common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + continue + } + midjourneyChannel, err := model.CacheGetChannel(channelId) + if err != nil { + common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) + err := model.MjBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + } + continue + } + requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL) + + body, _ := json.Marshal(map[string]any{ + "ids": taskIds, + }) + req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + continue + } + // 设置超时时间 + timeout := time.Second * 5 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("mj-api-secret", midjourneyChannel.Key) + resp, err := httpClient.Do(req) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + continue + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + continue + } + var responseItems []Midjourney + err = json.Unmarshal(responseBody, &responseItems) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err)) + continue + } + resp.Body.Close() + req.Body.Close() + cancel() + + for _, responseItem := range responseItems { + task := taskM[responseItem.MjId] + if !checkMjTaskNeedUpdate(task, responseItem) { + continue + } + + task.Code = 1 + task.Progress = responseItem.Progress + task.PromptEn = responseItem.PromptEn + task.State = responseItem.State + task.SubmitTime = responseItem.SubmitTime + task.StartTime = responseItem.StartTime + task.FinishTime = responseItem.FinishTime + task.ImageUrl = responseItem.ImageUrl + task.Status = responseItem.Status + task.FailReason = responseItem.FailReason + if task.Progress != "100%" && responseItem.FailReason != "" { + common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + task.Progress = "100%" + err = model.CacheUpdateUserQuota(task.UserId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } else { + modelRatio := common.GetModelRatio(imageModel) + groupRatio := common.GetGroupRatio("default") + ratio := modelRatio * groupRatio + quota := int(ratio * 1 * 1000) + if quota != 0 { + err = model.IncreaseUserQuota(task.UserId, quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + } + err = task.Update() + if err != nil { + common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + } + } + } + } +} + +func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool { + if oldTask.Code != 1 { + return true + } + if oldTask.Progress != newTask.Progress { + return true + } + if oldTask.PromptEn != newTask.PromptEn { + return true + } + if oldTask.State != newTask.State { + return true + } + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if oldTask.ImageUrl != newTask.ImageUrl { + return true + } + if oldTask.Status != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if oldTask.Progress != "100%" && newTask.FailReason != "" { + return true + } + + return false +} func GetAllMidjourney(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) diff --git a/main.go b/main.go index 470d7e0..3cfb29f 100644 --- a/main.go +++ b/main.go @@ -81,7 +81,7 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } - go controller.UpdateMidjourneyTask() + go controller.UpdateMidjourneyTaskBulk() if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") diff --git a/model/cache.go b/model/cache.go index a50c972..c575ad9 100644 --- a/model/cache.go +++ b/model/cache.go @@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { } var group2model2channels map[string]map[string][]*Channel +var channelsIDM map[int]*Channel var channelSyncLock sync.RWMutex func InitChannelCache() { @@ -149,10 +150,12 @@ func InitChannelCache() { groups[ability.Group] = true } newGroup2model2channels := make(map[string]map[string][]*Channel) + newChannelsIDM := make(map[int]*Channel) for group := range groups { newGroup2model2channels[group] = make(map[string][]*Channel) } for _, channel := range channels { + newChannelsIDM[channel.Id] = channel groups := strings.Split(channel.Group, ",") for _, group := range groups { models := strings.Split(channel.Models, ",") @@ -177,6 +180,7 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels + channelsIDM = newChannelsIDM channelSyncLock.Unlock() common.SysLog("channels synced from database") } @@ -217,3 +221,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error idx := rand.Intn(endIdx) return channels[idx], nil } + +func CacheGetChannel(id int) (*Channel, error) { + if !common.MemoryCacheEnabled { + return GetChannelById(id, true) + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + + c, ok := channelsIDM[id] + if !ok { + return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id)) + } + return c, nil +} diff --git a/model/midjourney.go b/model/midjourney.go index e24d8fc..84d228e 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error { err = DB.Save(midjourney).Error return err } + +func MjBulkUpdate(taskIDs []string, params map[string]any) error { + return DB.Model(&Midjourney{}). + Where("mj_id in (?)", taskIDs). + Updates(params).Error +}