package model

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"time"

	"one-api/constant"
	commonRelay "one-api/relay/common"
)

// TaskStatus is the lifecycle state of an asynchronous Task.
type TaskStatus string

// Task status values.
//
// Only TaskStatusNotStart is explicitly typed as TaskStatus; the remaining
// constants are untyped string constants, so they can be used wherever either a
// TaskStatus or a plain string is expected.
const (
	TaskStatusNotStart   TaskStatus = "NOT_START"
	TaskStatusSubmitted             = "SUBMITTED"
	TaskStatusQueued                = "QUEUED"
	TaskStatusInProgress            = "IN_PROGRESS"
	TaskStatusFailure               = "FAILURE"
	TaskStatusSuccess               = "SUCCESS"
	TaskStatusUnknown               = "UNKNOWN"
)

// Task is a persisted asynchronous job submitted to a third-party platform.
type Task struct {
	ID        int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
	CreatedAt int64 `json:"created_at" gorm:"index"`
	UpdatedAt int64 `json:"updated_at"`
	// TaskID is the third-party identifier (e.g. a song ID or task ID); it is
	// not always present.
	TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
	// Platform is the platform the task was submitted to.
	Platform  constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"`
	UserId    int                   `json:"user_id" gorm:"index"`
	ChannelId int                   `json:"channel_id" gorm:"index"`
	Quota     int                   `json:"quota"`
	// Action is the task type, e.g. song, lyrics, description-mode.
	Action string `json:"action" gorm:"type:varchar(40);index"`
	// Status is the current task status (see the TaskStatus constants).
	Status     TaskStatus `json:"status" gorm:"type:varchar(20);index"`
	FailReason string     `json:"fail_reason"`
	SubmitTime int64      `json:"submit_time" gorm:"index"`
	StartTime  int64      `json:"start_time" gorm:"index"`
	FinishTime int64      `json:"finish_time" gorm:"index"`
	// Progress is a percentage string; InitTask sets "0%" and
	// GetAllUnFinishSyncTasks treats anything other than "100%" as unfinished.
	Progress   string          `json:"progress" gorm:"type:varchar(20);index"`
	Properties Properties      `json:"properties" gorm:"type:json"`
	Data       json.RawMessage `json:"data" gorm:"type:json"`
}

// SetData JSON-encodes data into t.Data.
//
// The marshal error is deliberately discarded to keep the signature unchanged;
// if encoding fails, t.Data is set to nil.
func (t *Task) SetData(data any) {
	b, _ := json.Marshal(data)
	t.Data = json.RawMessage(b)
}

// GetData decodes t.Data into v, which must be a non-nil pointer (as with
// json.Unmarshal). It returns the decode error, including when t.Data is empty.
func (t *Task) GetData(v any) error {
	// Pass v itself, not &v: unmarshalling into a pointer to the interface would
	// silently decode into a local copy when v is not a pointer.
	return json.Unmarshal(t.Data, v)
}

// Properties holds extra task attributes, stored as a JSON column via the
// sql.Scanner / driver.Valuer implementations below.
type Properties struct {
	Input string `json:"input"`
}

// Scan implements sql.Scanner by decoding a JSON column value into m.
// A NULL (nil) or empty value resets m to the zero Properties; []byte and string
// values are decoded as JSON; any other type is reported as an error.
func (m *Properties) Scan(val any) error {
	var raw []byte
	switch v := val.(type) {
	case nil:
		*m = Properties{}
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("properties: unsupported scan type %T", val)
	}
	if len(raw) == 0 {
		*m = Properties{}
		return nil
	}
	return json.Unmarshal(raw, m)
}

// Value implements driver.Valuer by encoding m as JSON bytes.
func (m Properties) Value() (driver.Value, error) {
	return json.Marshal(m)
}

// SyncTaskQueryParams holds all search conditions for task queries; add more
// fields as needed. Zero-valued fields are treated as "no filter".
// StartTimestamp/EndTimestamp are compared against submit_time (presumably Unix
// seconds, matching InitTask's SubmitTime).
type SyncTaskQueryParams struct {
	Platform       constant.TaskPlatform
	ChannelID      string
	TaskID         string
	UserID         string
	Action         string
	Status         string
	StartTimestamp int64
	EndTimestamp   int64
	UserIDs        []int
}

// InitTask builds (but does not persist) a new Task for platform, owned by the
// user and channel in relayInfo. The task starts in TaskStatusNotStart with
// progress "0%", and its SubmitTime is set to the current Unix time in seconds.
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
	return &Task{
		UserId:     relayInfo.UserId,
		SubmitTime: time.Now().Unix(),
		Status:     TaskStatusNotStart,
		Progress:   "0%",
		ChannelId:  relayInfo.ChannelId,
		Platform:   platform,
	}
}

// TaskGetAllUserTask returns one page (offset startIdx, limit num) of tasks
// owned by userId, ordered by id descending.
//
// Only the TaskID, Action, Status, Platform and timestamp fields of queryParams
// are applied, because the query is already scoped to userId. The channel_id
// column is omitted from the selected columns (presumably to avoid exposing
// channel details to end users — confirm). On a database error, the error is
// discarded and nil is returned.
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
	query := DB.Where("user_id = ?", userId)
	if queryParams.TaskID != "" {
		query = query.Where("task_id = ?", queryParams.TaskID)
	}
	if queryParams.Action != "" {
		query = query.Where("action = ?", queryParams.Action)
	}
	if queryParams.Status != "" {
		query = query.Where("status = ?", queryParams.Status)
	}
	if queryParams.Platform != "" {
		query = query.Where("platform = ?", queryParams.Platform)
	}
	if queryParams.StartTimestamp != 0 {
		// Assumes the caller has already converted and validated the timestamp
		// coming from the frontend into the format stored in submit_time.
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
	}
	if queryParams.EndTimestamp != 0 {
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
	}

	var tasks []*Task
	if err := query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error; err != nil {
		return nil
	}
	return tasks
}

// TaskGetAllTasks returns one page (offset startIdx, limit num) of tasks across
// all users, ordered by id descending and filtered by every non-zero field of
// queryParams. On a database error, the error is discarded and nil is returned.
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
	query := DB
	if queryParams.ChannelID != "" {
		query = query.Where("channel_id = ?", queryParams.ChannelID)
	}
	if queryParams.Platform != "" {
		query = query.Where("platform = ?", queryParams.Platform)
	}
	if queryParams.UserID != "" {
		query = query.Where("user_id = ?", queryParams.UserID)
	}
	if len(queryParams.UserIDs) != 0 {
		query = query.Where("user_id in (?)", queryParams.UserIDs)
	}
	if queryParams.TaskID != "" {
		query = query.Where("task_id = ?", queryParams.TaskID)
	}
	if queryParams.Action != "" {
		query = query.Where("action = ?", queryParams.Action)
	}
	if queryParams.Status != "" {
		query = query.Where("status = ?", queryParams.Status)
	}
	if queryParams.StartTimestamp != 0 {
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
	}
	if queryParams.EndTimestamp != 0 {
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
	}

	var tasks []*Task
	if err := query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error; err != nil {
		return nil
	}
	return tasks
}

// GetAllUnFinishSyncTasks returns up to limit tasks whose progress is not
// "100%", in ascending id order. On a database error, the error is discarded
// and nil is returned.
func GetAllUnFinishSyncTasks(limit int) []*Task {
	var tasks []*Task
	if err := DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error; err != nil {
		return nil
	}
	return tasks
}

// GetByOnlyTaskId looks up a task by its third-party task_id, regardless of
// owner. An empty taskId short-circuits to (nil, false, nil). The bool reports
// whether a matching row exists, as determined by RecordExist.
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
	if taskId == "" {
		return nil, false, nil
	}
	var task *Task
	exist, err := RecordExist(DB.Where("task_id = ?", taskId).First(&task).Error)
	if err != nil {
		return nil, false, err
	}
	return task, exist, nil
}

// GetByTaskId looks up a task by task_id, restricted to tasks owned by userId.
// An empty taskId short-circuits to (nil, false, nil). The bool reports whether
// a matching row exists, as determined by RecordExist.
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
	if taskId == "" {
		return nil, false, nil
	}
	var task *Task
	err := DB.Where("user_id = ? and task_id = ?", userId, taskId).
		First(&task).Error
	exist, err := RecordExist(err)
	if err != nil {
		return nil, false, err
	}
	return task, exist, nil
}

// GetByTaskIds returns the tasks owned by userId whose task_id is in taskIds.
// An empty taskIds returns (nil, nil) without querying.
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
	if len(taskIds) == 0 {
		return nil, nil
	}
	var tasks []*Task
	if err := DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
		Find(&tasks).Error; err != nil {
		return nil, err
	}
	return tasks, nil
}

// TaskUpdateProgress sets the progress column of the task with primary key id.
func TaskUpdateProgress(id int64, progress string) error {
	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
}

// Insert creates t as a new row.
func (t *Task) Insert() error {
	return DB.Create(t).Error
}

// Update saves all fields of t via DB.Save.
func (t *Task) Update() error {
	return DB.Save(t).Error
}

// TaskBulkUpdate applies the column/value pairs in params to every task whose
// task_id is in taskIDs. It is a no-op for an empty taskIDs.
func TaskBulkUpdate(taskIDs []string, params map[string]any) error {
	if len(taskIDs) == 0 {
		return nil
	}
	return DB.Model(&Task{}).
		Where("task_id in (?)", taskIDs).
		Updates(params).Error
}

// TaskBulkUpdateByTaskIds applies params to every task whose primary key id is
// in taskIDs. Despite its name, it matches on the id column, not task_id; it is
// identical to TaskBulkUpdateByID and kept for existing callers.
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
	return TaskBulkUpdateByID(taskIDs, params)
}

// TaskBulkUpdateByID applies the column/value pairs in params to every task
// whose primary key id is in ids. It is a no-op for an empty ids.
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
	if len(ids) == 0 {
		return nil
	}
	return DB.Model(&Task{}).
		Where("id in (?)", ids).
		Updates(params).Error
}

// TaskQuotaUsage is one aggregated row of SumUsedTaskQuota.
type TaskQuotaUsage struct {
	Mode  string  `json:"mode"`
	Count float64 `json:"count"`
}

// SumUsedTaskQuota sums quota per group over the tasks matching every non-zero
// field of queryParams.
//
// NOTE(review): the Task struct declares no `mode` column, so unless the tasks
// table gains one elsewhere, this SELECT/GROUP BY will fail at runtime. Grouping
// by `action` or `platform` may have been intended — confirm before relying on it.
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
	query := DB.Model(Task{})
	if queryParams.ChannelID != "" {
		query = query.Where("channel_id = ?", queryParams.ChannelID)
	}
	// Honour Platform like TaskGetAllTasks does with the same params struct.
	if queryParams.Platform != "" {
		query = query.Where("platform = ?", queryParams.Platform)
	}
	if queryParams.UserID != "" {
		query = query.Where("user_id = ?", queryParams.UserID)
	}
	if len(queryParams.UserIDs) != 0 {
		query = query.Where("user_id in (?)", queryParams.UserIDs)
	}
	if queryParams.TaskID != "" {
		query = query.Where("task_id = ?", queryParams.TaskID)
	}
	if queryParams.Action != "" {
		query = query.Where("action = ?", queryParams.Action)
	}
	if queryParams.Status != "" {
		query = query.Where("status = ?", queryParams.Status)
	}
	if queryParams.StartTimestamp != 0 {
		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
	}
	if queryParams.EndTimestamp != 0 {
		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
	}
	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
	return stat, err
}