fix: 修复mj错误返还费用问题

This commit is contained in:
CaIon 2024-01-12 13:45:52 +08:00
parent 312417f393
commit 2a9c3ac6af
3 changed files with 26 additions and 7 deletions

View File

@ -154,7 +154,7 @@ func UpdateMidjourneyTaskBulk() {
log.Printf("UpdateMidjourneyTask panic: %v", err) log.Printf("UpdateMidjourneyTask panic: %v", err)
} }
}() }()
imageModel := "midjourney" //imageModel := "midjourney"
ctx := context.TODO() ctx := context.TODO()
for { for {
time.Sleep(time.Duration(15) * time.Second) time.Sleep(time.Duration(15) * time.Second)
@ -167,13 +167,27 @@ func UpdateMidjourneyTaskBulk() {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string) taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney) taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks { for _, task := range tasks {
if task.MjId == "" { if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue continue
} }
taskM[task.MjId] = task taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId) taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
} }
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 { if len(taskChannelM) == 0 {
continue continue
} }
@ -256,10 +270,7 @@ func UpdateMidjourneyTaskBulk() {
if err != nil { if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error()) common.LogError(ctx, "error update user quota cache: "+err.Error())
} else { } else {
modelRatio := common.GetModelRatio(imageModel) quota := task.Quota
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 { if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota) err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil { if err != nil {

View File

@ -544,6 +544,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Progress: "0%", Progress: "0%",
FailReason: "", FailReason: "",
ChannelId: c.GetInt("channel_id"), ChannelId: c.GetInt("channel_id"),
Quota: quota,
} }
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 { if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

View File

@ -18,6 +18,7 @@ type Midjourney struct {
Progress string `json:"progress"` Progress string `json:"progress"`
FailReason string `json:"fail_reason"` FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"` ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
} }
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@ -152,8 +153,14 @@ func (midjourney *Midjourney) Update() error {
return err return err
} }
func MjBulkUpdate(taskIDs []string, params map[string]any) error { func MjBulkUpdate(mjIds []string, params map[string]any) error {
return DB.Model(&Midjourney{}). return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs). Where("mj_id in (?)", mjIds).
Updates(params).Error
}
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("id in (?)", taskIDs).
Updates(params).Error Updates(params).Error
} }