fix: 优化 mj 获取进度

This commit is contained in:
Xyfacai 2023-12-23 23:14:58 +08:00
parent 7c4719b6ee
commit fd4ef086dc
4 changed files with 212 additions and 11 deletions

View File

@ -16,8 +16,10 @@ import (
"time" "time"
) )
func UpdateMidjourneyTask() { /*func UpdateMidjourneyTask() {
//revocer //revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney" imageModel := "midjourney"
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
time.Sleep(time.Duration(15) * time.Second) time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks() tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 { if len(tasks) != 0 {
log.Printf("检测到未完成的任务数有: %v", len(tasks)) common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks { for _, task := range tasks {
log.Printf("未完成的任务信息: %v", task) common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true) midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil { if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err) common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId) task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE" task.Status = "FAILURE"
task.Progress = "100%" task.Progress = "100%"
err := task.Update() err := task.Update()
if err != nil { if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err) common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
} }
continue continue
} }
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId) requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
log.Printf("requestUrl: %s", requestUrl) common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte(""))) req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil { if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err) common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue continue
} }
@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
task.Status = responseItem.Status task.Status = responseItem.Status
task.FailReason = responseItem.FailReason task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" { if task.Progress != "100%" && responseItem.FailReason != "" {
log.Println(task.MjId + " 构建失败," + task.FailReason) common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%" task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId) err = model.CacheUpdateUserQuota(task.UserId)
if err != nil { if err != nil {
@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
if err != nil { if err != nil {
log.Println("fail to increase user quota") log.Println("fail to increase user quota")
} }
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota)) logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, 1, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
} }
} }
@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
} }
} }
} }
*/
func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
for _, task := range tasks {
if task.MjId == "" {
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
return false
}
func GetAllMidjourney(c *gin.Context) { func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))

View File

@ -81,7 +81,7 @@ func main() {
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
go controller.UpdateMidjourneyTask() go controller.UpdateMidjourneyTaskBulk()
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
} }
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
func InitChannelCache() { func InitChannelCache() {
@ -149,10 +150,12 @@ func InitChannelCache() {
groups[ability.Group] = true groups[ability.Group] = true
} }
newGroup2model2channels := make(map[string]map[string][]*Channel) newGroup2model2channels := make(map[string]map[string][]*Channel)
newChannelsIDM := make(map[int]*Channel)
for group := range groups { for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel) newGroup2model2channels[group] = make(map[string][]*Channel)
} }
for _, channel := range channels { for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
groups := strings.Split(channel.Group, ",") groups := strings.Split(channel.Group, ",")
for _, group := range groups { for _, group := range groups {
models := strings.Split(channel.Models, ",") models := strings.Split(channel.Models, ",")
@ -177,6 +180,7 @@ func InitChannelCache() {
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelSyncLock.Unlock() channelSyncLock.Unlock()
common.SysLog("channels synced from database") common.SysLog("channels synced from database")
} }
@ -217,3 +221,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
idx := rand.Intn(endIdx) idx := rand.Intn(endIdx)
return channels[idx], nil return channels[idx], nil
} }
func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d已不存在", id))
}
return c, nil
}

View File

@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error {
err = DB.Save(midjourney).Error err = DB.Save(midjourney).Error
return err return err
} }
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Updates(params).Error
}