mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-12 20:23:46 +08:00
merge v4.1.3
This commit is contained in:
@@ -35,9 +35,10 @@ type Service struct {
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
@@ -45,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,32 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
return "", errors.New("insufficient of power")
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
|
||||
// 记录算力变化日志
|
||||
if tx.Error == nil && tx.RowsAffected > 0 {
|
||||
var u model.User
|
||||
s.db.Where("id", user.Id).First(&u)
|
||||
s.db.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: types.PowerConsume,
|
||||
Amount: task.Power,
|
||||
Balance: u.Power,
|
||||
Mark: types.PowerSub,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
// 扣减算力
|
||||
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "dall-e-3",
|
||||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with decrease power: %v", err)
|
||||
}
|
||||
|
||||
// get image generation API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx = s.db.Where("type", "dalle").
|
||||
err = s.db.Where("type", "dalle").
|
||||
Where("enabled", true).
|
||||
Order("last_used_at ASC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return "", fmt.Errorf("no available DALL-E api key: %v", tx.Error)
|
||||
Order("last_used_at ASC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no available DALL-E api key: %v", err)
|
||||
}
|
||||
|
||||
var res imgRes
|
||||
@@ -181,13 +174,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
// update the api key last use time
|
||||
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
// update task progress
|
||||
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||
err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
|
||||
"progress": 100,
|
||||
"org_url": res.Data[0].Url,
|
||||
"prompt": prompt,
|
||||
})
|
||||
if tx.Error != nil {
|
||||
return "", fmt.Errorf("err with update database: %v", tx.Error)
|
||||
}).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("err with update database: %v", err)
|
||||
}
|
||||
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
|
||||
|
||||
@@ -82,14 +82,21 @@ func (s *Service) Run() {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
r, err := s.Create(task)
|
||||
var r RespVo
|
||||
if task.Type == 3 && task.SongId != "" { // 歌曲拼接
|
||||
r, err = s.Merge(task)
|
||||
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
|
||||
r, err = s.Upload(task)
|
||||
} else { // 歌曲创作
|
||||
r, err = s.Create(task)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
})
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -138,7 +145,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
|
||||
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
@@ -164,6 +171,97 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"clip_id": task.SongId,
|
||||
"is_infill": false,
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
if res.Code != "success" {
|
||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||
}
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
session.Updates(&apiKey)
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"url": task.AudioURL,
|
||||
}
|
||||
|
||||
var res RespVo
|
||||
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.StatusCode != 200 {
|
||||
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
if res.Code != "success" {
|
||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||
}
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
session.Updates(&apiKey)
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
@@ -185,7 +283,7 @@ func (s *Service) CheckTaskNotify() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
func (s *Service) DownloadFiles() {
|
||||
go func() {
|
||||
var items []model.SunoJob
|
||||
for {
|
||||
@@ -331,15 +429,15 @@ type QueryRespVo struct {
|
||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
||||
err := s.db.Session(&gorm.Session{}).Where("type", "suno").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
Order("last_used_at DESC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return QueryRespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
|
||||
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
|
||||
var res QueryRespVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
|
||||
83
api/service/user_service.go
Normal file
83
api/service/user_service.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"gorm.io/gorm"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
db *gorm.DB
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB) *UserService {
|
||||
return &UserService{db: db, lock: sync.Mutex{}}
|
||||
}
|
||||
|
||||
// IncreasePower 增加用户算力
|
||||
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
tx := s.db.Begin()
|
||||
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
var user model.User
|
||||
tx.Where("id", userId).First(&user)
|
||||
err = tx.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: log.Type,
|
||||
Amount: power,
|
||||
Balance: user.Power,
|
||||
Mark: types.PowerAdd,
|
||||
Model: log.Model,
|
||||
Remark: log.Remark,
|
||||
CreatedAt: time.Now(),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecreasePower 减少用户算力
|
||||
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
tx := s.db.Begin()
|
||||
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("扣减算力失败:%v", err)
|
||||
}
|
||||
var user model.User
|
||||
tx.Where("id", userId).First(&user)
|
||||
err = tx.Create(&model.PowerLog{
|
||||
UserId: user.Id,
|
||||
Username: user.Username,
|
||||
Type: log.Type,
|
||||
Amount: power,
|
||||
Balance: user.Power,
|
||||
Mark: types.PowerSub,
|
||||
Model: log.Model,
|
||||
Remark: log.Remark,
|
||||
CreatedAt: time.Now(),
|
||||
}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("记录算力日志失败:%v", err)
|
||||
}
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
330
api/service/video/luma.go
Normal file
330
api/service/video/luma.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package video
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
taskQueue *store.RedisQueue
|
||||
notifyQueue *store.RedisQueue
|
||||
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
|
||||
}
|
||||
|
||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
|
||||
return &Service{
|
||||
httpClient: req.C().SetTimeout(time.Minute * 3),
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
|
||||
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
|
||||
Clients: types.NewLMap[uint, *types.WsClient](),
|
||||
uploadManager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.VideoTask) {
|
||||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
var jobs []model.VideoJob
|
||||
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
var params types.VideoParams
|
||||
if err := utils.JsonDecode(v.Params, ¶ms); err != nil {
|
||||
logger.Errorf("unmarshal params failed: %v", err)
|
||||
continue
|
||||
}
|
||||
s.PushTask(types.VideoTask{
|
||||
Id: v.Id,
|
||||
Channel: v.Channel,
|
||||
UserId: v.UserId,
|
||||
Type: v.Type,
|
||||
TaskId: v.TaskId,
|
||||
Prompt: v.Prompt,
|
||||
Params: params,
|
||||
})
|
||||
}
|
||||
logger.Info("Starting Video job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
var task types.VideoTask
|
||||
err := s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
var r LumaRespVo
|
||||
r, err = s.LumaCreate(task)
|
||||
if err != nil {
|
||||
logger.Errorf("create task with error: %v", err)
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"err_msg": err.Error(),
|
||||
"progress": service.FailTaskProgress,
|
||||
"cover_url": "/images/failed.jpg",
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
}
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
|
||||
"task_id": r.Id,
|
||||
"channel": r.Channel,
|
||||
"prompt_ext": r.Prompt,
|
||||
}).Error
|
||||
if err != nil {
|
||||
logger.Errorf("update task with error: %v", err)
|
||||
s.PushTask(task)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type LumaRespVo struct {
|
||||
Id string `json:"id"`
|
||||
Prompt string `json:"prompt"`
|
||||
State string `json:"state"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Video interface{} `json:"video"`
|
||||
Liked interface{} `json:"liked"`
|
||||
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
|
||||
if task.Channel != "" {
|
||||
session = session.Where("api_url", task.Channel)
|
||||
}
|
||||
tx := session.Order("last_used_at DESC").First(&apiKey)
|
||||
if tx.Error != nil {
|
||||
return LumaRespVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"user_prompt": task.Prompt,
|
||||
"expand_prompt": task.Params.PromptOptimize,
|
||||
"loop": task.Params.Loop,
|
||||
"image_url": task.Params.StartImgURL,
|
||||
"image_end_url": task.Params.EndImgURL,
|
||||
}
|
||||
var res LumaRespVo
|
||||
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
|
||||
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
|
||||
r, err := req.C().R().
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(reqBody).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||
}
|
||||
|
||||
if r.StatusCode != 200 && r.StatusCode != 201 {
|
||||
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
session.Updates(&apiKey)
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
for {
|
||||
var message service.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadFiles() {
|
||||
go func() {
|
||||
var items []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress", 102).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
if v.WaterURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download video: %s", v.WaterURL)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("download video success: %s", videoURL)
|
||||
v.WaterURL = videoURL
|
||||
|
||||
if v.VideoURL != "" {
|
||||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
logger.Info("download no water video success: %s", videoURL)
|
||||
v.VideoURL = videoURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
var jobs []model.VideoJob
|
||||
for {
|
||||
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
|
||||
if err != nil {
|
||||
logger.Errorf("query task with error: %v", err)
|
||||
// 更新任务信息
|
||||
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||||
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debugf("task: %+v", task)
|
||||
if task.State == "completed" { // 更新任务信息
|
||||
data := map[string]interface{}{
|
||||
"progress": 102, // 102 表示资源未下载完成,
|
||||
"water_url": task.Video.Url,
|
||||
"raw_data": utils.JsonEncode(task),
|
||||
"prompt_ext": task.Prompt,
|
||||
}
|
||||
if task.Video.DownloadUrl != "" {
|
||||
data["video_url"] = task.Video.DownloadUrl
|
||||
}
|
||||
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
|
||||
if err != nil {
|
||||
logger.Errorf("更新数据库失败:%v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type LumaTaskVo struct {
|
||||
Id string `json:"id"`
|
||||
Liked interface{} `json:"liked"`
|
||||
State string `json:"state"`
|
||||
Video struct {
|
||||
Url string `json:"url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
DownloadUrl string `json:"download_url"`
|
||||
} `json:"video"`
|
||||
Prompt string `json:"prompt"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
|
||||
}
|
||||
|
||||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
|
||||
// 读取 API KEY
|
||||
var apiKey model.ApiKey
|
||||
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
|
||||
Where("api_url", channel).
|
||||
Where("enabled", true).
|
||||
Order("last_used_at DESC").First(&apiKey).Error
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
|
||||
var res LumaTaskVo
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
|
||||
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
if r.StatusCode != 200 {
|
||||
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
@@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
|
||||
// 自动将 VIP 会员的算力补充到每月赠送的最大值
|
||||
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||
logger.Info("开始进行月底账号盘点...")
|
||||
var users []model.User
|
||||
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
|
||||
if res.Error != nil {
|
||||
return "No vip users found"
|
||||
}
|
||||
|
||||
var sysConfig model.Config
|
||||
res = e.db.Where("marker", "system").First(&sysConfig)
|
||||
if res.Error != nil {
|
||||
return "error with get system config: " + res.Error.Error()
|
||||
}
|
||||
|
||||
var config types.SystemConfig
|
||||
err := utils.JsonDecode(sysConfig.Config, &config)
|
||||
if err != nil {
|
||||
return "error with decode system config: " + err.Error()
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
// 处理过期的 VIP
|
||||
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
|
||||
u.Vip = false
|
||||
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
|
||||
continue
|
||||
}
|
||||
if u.Power < config.VipMonthPower {
|
||||
power := config.VipMonthPower - u.Power
|
||||
// update user
|
||||
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
|
||||
// 记录算力变动日志
|
||||
if tx.Error == nil {
|
||||
var user model.User
|
||||
e.db.Where("id", u.Id).First(&user)
|
||||
e.db.Create(&model.PowerLog{
|
||||
UserId: u.Id,
|
||||
Username: u.Username,
|
||||
Type: types.PowerRecharge,
|
||||
Amount: power,
|
||||
Mark: types.PowerAdd,
|
||||
Balance: user.Power,
|
||||
Model: "系统盘点",
|
||||
Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Info("月底盘点完成!")
|
||||
return "success"
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user