mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	luma create and list api is ready
This commit is contained in:
		@@ -8,6 +8,7 @@
 | 
			
		||||
* 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
 | 
			
		||||
* Bug修复:手机端角色和模型选择不生效
 | 
			
		||||
* Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
 | 
			
		||||
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
 | 
			
		||||
* 功能新增:支持 Luma 文生视频功能
 | 
			
		||||
 | 
			
		||||
## v4.1.2
 | 
			
		||||
 
 | 
			
		||||
@@ -150,8 +150,9 @@ type SystemConfig struct {
 | 
			
		||||
	MjPower       int `json:"mj_power,omitempty"`        // MJ 绘画消耗算力
 | 
			
		||||
	MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
 | 
			
		||||
	SdPower       int `json:"sd_power,omitempty"`        // SD 绘画消耗算力
 | 
			
		||||
	DallPower     int `json:"dall_power,omitempty"`      // DALLE3 绘图消耗算力
 | 
			
		||||
	DallPower     int `json:"dall_power,omitempty"`      // DALL-E-3 绘图消耗算力
 | 
			
		||||
	SunoPower     int `json:"suno_power,omitempty"`      // Suno 生成歌曲消耗算力
 | 
			
		||||
	LumaPower     int `json:"luma_power,omitempty"`      // Luma 生成视频消耗算力
 | 
			
		||||
 | 
			
		||||
	WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -166,8 +166,8 @@ func (h *SunoHandler) Create(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func (h *SunoHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
 | 
			
		||||
 | 
			
		||||
	// 统计总数
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										240
									
								
								api/handler/video_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								api/handler/video_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,240 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/video"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type VideoHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	service  *video.Service
 | 
			
		||||
	uploader *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager) *VideoHandler {
 | 
			
		||||
	return &VideoHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
		service:  service,
 | 
			
		||||
		uploader: uploader,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *VideoHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.service.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Prompt        string `json:"prompt"`
 | 
			
		||||
		FirstFrameImg string `json:"first_frame_img,omitempty"`
 | 
			
		||||
		EndFrameImg   string `json:"end_frame_img,omitempty"`
 | 
			
		||||
		ExpandPrompt  bool   `json:"expand_prompt,omitempty"`
 | 
			
		||||
		Loop          bool   `json:"loop,omitempty"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if data.Prompt == "" {
 | 
			
		||||
		resp.ERROR(c, "prompt is needed")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := int(h.GetLoginUserId(c))
 | 
			
		||||
	params := types.VideoParams{
 | 
			
		||||
		PromptOptimize: data.ExpandPrompt,
 | 
			
		||||
		Loop:           data.Loop,
 | 
			
		||||
		StartImgURL:    data.FirstFrameImg,
 | 
			
		||||
		EndImgURL:      data.EndFrameImg,
 | 
			
		||||
	}
 | 
			
		||||
	// 插入数据库
 | 
			
		||||
	job := model.VideoJob{
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
		Type:   types.VideoLuma,
 | 
			
		||||
		Prompt: data.Prompt,
 | 
			
		||||
		Power:  h.App.SysConfig.LumaPower,
 | 
			
		||||
		Params: utils.JsonEncode(params),
 | 
			
		||||
	}
 | 
			
		||||
	tx := h.DB.Create(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建任务
 | 
			
		||||
	h.service.PushTask(types.VideoTask{
 | 
			
		||||
		Id:     job.Id,
 | 
			
		||||
		UserId: userId,
 | 
			
		||||
		Type:   types.VideoLuma,
 | 
			
		||||
		Prompt: data.Prompt,
 | 
			
		||||
		Params: params,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "luma",
 | 
			
		||||
			Remark:    fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.service.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	t := c.Query("type")
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
 | 
			
		||||
	if t != "" {
 | 
			
		||||
		session = session.Where("type", t)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计总数
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.VideoJob{}).Count(&total)
 | 
			
		||||
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
	var list []model.VideoJob
 | 
			
		||||
	err := session.Order("id desc").Find(&list).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 转换为 VO
 | 
			
		||||
	items := make([]vo.VideoJob, 0)
 | 
			
		||||
	for _, v := range list {
 | 
			
		||||
		var item vo.VideoJob
 | 
			
		||||
		err = utils.CopyObject(v, &item)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		item.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
		items = append(items, item)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var job model.VideoJob
 | 
			
		||||
	err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 删除任务
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	if err := tx.Delete(&job).Error; err != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果任务未完成,或者任务失败,则恢复用户算力
 | 
			
		||||
	if job.Progress != 100 {
 | 
			
		||||
		err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		var user model.User
 | 
			
		||||
		tx.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
		err = tx.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerRefund,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power,
 | 
			
		||||
			Mark:      types.PowerAdd,
 | 
			
		||||
			Model:     "luma",
 | 
			
		||||
			Remark:    fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
 | 
			
		||||
	// 删除文件
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
 | 
			
		||||
	_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *VideoHandler) Publish(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
	err := h.DB.Model(&model.VideoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								api/main.go
									
									
									
									
									
								
							@@ -24,6 +24,7 @@ import (
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/service/sms"
 | 
			
		||||
	"geekai/service/suno"
 | 
			
		||||
	"geekai/service/video"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
@@ -201,6 +202,13 @@ func main() {
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadFiles()
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(video.NewService),
 | 
			
		||||
		fx.Invoke(func(s *video.Service) {
 | 
			
		||||
			s.Run()
 | 
			
		||||
			s.SyncTaskProgress()
 | 
			
		||||
			s.CheckTaskNotify()
 | 
			
		||||
			s.DownloadFiles()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(payment.NewAlipayService),
 | 
			
		||||
		fx.Provide(payment.NewHuPiPay),
 | 
			
		||||
@@ -484,6 +492,15 @@ func main() {
 | 
			
		||||
			group.GET("play", h.Play)
 | 
			
		||||
			group.POST("lyric", h.Lyric)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewVideoHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/video")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("luma/create", h.LumaCreate)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewTestHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/test")
 | 
			
		||||
 
 | 
			
		||||
@@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != 200 {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -82,8 +82,8 @@ func (s *Service) Run() {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			var r RespVo
 | 
			
		||||
			r, err = s.CreateLuma(task)
 | 
			
		||||
			var r LumaRespVo
 | 
			
		||||
			r, err = s.LumaCreate(task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("create task with error: %v", err)
 | 
			
		||||
				s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
@@ -96,14 +96,15 @@ func (s *Service) Run() {
 | 
			
		||||
 | 
			
		||||
			// 更新任务信息
 | 
			
		||||
			s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
				"task_id": r.Id,
 | 
			
		||||
				"channel": r.Channel,
 | 
			
		||||
				"task_id":    r.Id,
 | 
			
		||||
				"channel":    r.Channel,
 | 
			
		||||
				"prompt_ext": r.Prompt,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RespVo struct {
 | 
			
		||||
type LumaRespVo struct {
 | 
			
		||||
	Id                  string      `json:"id"`
 | 
			
		||||
	Prompt              string      `json:"prompt"`
 | 
			
		||||
	State               string      `json:"state"`
 | 
			
		||||
@@ -114,7 +115,7 @@ type RespVo struct {
 | 
			
		||||
	Channel             string      `json:"channel,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
 | 
			
		||||
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
 | 
			
		||||
@@ -123,7 +124,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
 | 
			
		||||
	}
 | 
			
		||||
	tx := session.Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return RespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
		return LumaRespVo{}, errors.New("no available API KEY for Luma")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBody := map[string]interface{}{
 | 
			
		||||
@@ -133,7 +134,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
 | 
			
		||||
		"image_url":     task.Params.StartImgURL,
 | 
			
		||||
		"image_end_url": task.Params.EndImgURL,
 | 
			
		||||
	}
 | 
			
		||||
	var res RespVo
 | 
			
		||||
	var res LumaRespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
 | 
			
		||||
	logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
@@ -141,13 +142,17 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
 | 
			
		||||
		SetBody(reqBody).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.StatusCode != 200 {
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
		return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update the last_use_at for api key
 | 
			
		||||
@@ -180,7 +185,7 @@ func (s *Service) CheckTaskNotify() {
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadFiles() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.SunoJob
 | 
			
		||||
		var items []model.VideoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress", 102).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
@@ -188,22 +193,13 @@ func (s *Service) DownloadFiles() {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				// 下载图片和音频
 | 
			
		||||
				logger.Infof("try download cover image: %s", v.CoverURL)
 | 
			
		||||
				coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("download image with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try download audio: %s", v.AudioURL)
 | 
			
		||||
				audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
 | 
			
		||||
				logger.Infof("try download video: %s", v.VideoURL)
 | 
			
		||||
				videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("download audio with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				v.CoverURL = coverURL
 | 
			
		||||
				v.AudioURL = audioURL
 | 
			
		||||
				v.VideoURL = videoURL
 | 
			
		||||
				v.Progress = 100
 | 
			
		||||
				s.db.Updates(&v)
 | 
			
		||||
				s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
 | 
			
		||||
@@ -217,7 +213,7 @@ func (s *Service) DownloadFiles() {
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (s *Service) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var jobs []model.SunoJob
 | 
			
		||||
		var jobs []model.VideoJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
@@ -225,60 +221,14 @@ func (s *Service) SyncTaskProgress() {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				task, err := s.QueryTask(job.TaskId, job.Channel)
 | 
			
		||||
				task, err := s.QueryLumaTask(job.TaskId, job.Channel)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("query task with error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if task.Code != "success" {
 | 
			
		||||
					logger.Errorf("query task with error: %v", task.Message)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				logger.Debugf("task: %+v", task)
 | 
			
		||||
 | 
			
		||||
				logger.Debugf("task: %+v", task.Data.Status)
 | 
			
		||||
				// 任务完成,删除旧任务插入两条新任务
 | 
			
		||||
				if task.Data.Status == "SUCCESS" {
 | 
			
		||||
					var jobId = job.Id
 | 
			
		||||
					var flag = false
 | 
			
		||||
					tx := s.db.Begin()
 | 
			
		||||
					for _, v := range task.Data.Data {
 | 
			
		||||
						job.Id = 0
 | 
			
		||||
						job.Progress = 102 // 102 表示资源未下载完成
 | 
			
		||||
						job.Title = v.Title
 | 
			
		||||
						job.SongId = v.Id
 | 
			
		||||
						job.Duration = int(v.Metadata.Duration)
 | 
			
		||||
						job.Prompt = v.Metadata.Prompt
 | 
			
		||||
						job.Tags = v.Metadata.Tags
 | 
			
		||||
						job.ModelName = v.ModelName
 | 
			
		||||
						job.RawData = utils.JsonEncode(v)
 | 
			
		||||
						job.CoverURL = v.ImageLargeUrl
 | 
			
		||||
						job.AudioURL = v.AudioUrl
 | 
			
		||||
 | 
			
		||||
						if err = tx.Create(&job).Error; err != nil {
 | 
			
		||||
							logger.Error("create job with error: %v", err)
 | 
			
		||||
							tx.Rollback()
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
						flag = true
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// 删除旧任务
 | 
			
		||||
					if flag {
 | 
			
		||||
						if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
 | 
			
		||||
							logger.Error("create job with error: %v", err)
 | 
			
		||||
							tx.Rollback()
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					tx.Commit()
 | 
			
		||||
 | 
			
		||||
				} else if task.Data.FailReason != "" {
 | 
			
		||||
					job.Progress = service.FailTaskProgress
 | 
			
		||||
					job.ErrMsg = task.Data.FailReason
 | 
			
		||||
					s.db.Updates(&job)
 | 
			
		||||
					s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
@@ -286,42 +236,22 @@ func (s *Service) SyncTaskProgress() {
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type QueryRespVo struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		TaskId     string `json:"task_id"`
 | 
			
		||||
		Action     string `json:"action"`
 | 
			
		||||
		Status     string `json:"status"`
 | 
			
		||||
		FailReason string `json:"fail_reason"`
 | 
			
		||||
		SubmitTime int    `json:"submit_time"`
 | 
			
		||||
		StartTime  int    `json:"start_time"`
 | 
			
		||||
		FinishTime int    `json:"finish_time"`
 | 
			
		||||
		Progress   string `json:"progress"`
 | 
			
		||||
		Data       []struct {
 | 
			
		||||
			Id       string `json:"id"`
 | 
			
		||||
			Title    string `json:"title"`
 | 
			
		||||
			Status   string `json:"status"`
 | 
			
		||||
			Metadata struct {
 | 
			
		||||
				Tags         string      `json:"tags"`
 | 
			
		||||
				Type         string      `json:"type"`
 | 
			
		||||
				Prompt       string      `json:"prompt"`
 | 
			
		||||
				Stream       bool        `json:"stream"`
 | 
			
		||||
				Duration     float64     `json:"duration"`
 | 
			
		||||
				ErrorMessage interface{} `json:"error_message"`
 | 
			
		||||
			} `json:"metadata"`
 | 
			
		||||
			AudioUrl          string `json:"audio_url"`
 | 
			
		||||
			ImageUrl          string `json:"image_url"`
 | 
			
		||||
			VideoUrl          string `json:"video_url"`
 | 
			
		||||
			ModelName         string `json:"model_name"`
 | 
			
		||||
			DisplayName       string `json:"display_name"`
 | 
			
		||||
			ImageLargeUrl     string `json:"image_large_url"`
 | 
			
		||||
			MajorModelVersion string `json:"major_model_version"`
 | 
			
		||||
		} `json:"data"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
type LumaTaskVo struct {
 | 
			
		||||
	Id    string      `json:"id"`
 | 
			
		||||
	Liked interface{} `json:"liked"`
 | 
			
		||||
	State string      `json:"state"`
 | 
			
		||||
	Video struct {
 | 
			
		||||
		Url         string `json:"url"`
 | 
			
		||||
		Width       int    `json:"width"`
 | 
			
		||||
		Height      int    `json:"height"`
 | 
			
		||||
		DownloadUrl string `json:"download_url"`
 | 
			
		||||
	} `json:"video"`
 | 
			
		||||
	Prompt              string      `json:"prompt"`
 | 
			
		||||
	CreatedAt           time.Time   `json:"created_at"`
 | 
			
		||||
	EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
 | 
			
		||||
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
 | 
			
		||||
	// 读取 API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
 | 
			
		||||
@@ -329,22 +259,22 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error)
 | 
			
		||||
		Where("enabled", true).
 | 
			
		||||
		Order("last_used_at DESC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return QueryRespVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
		return LumaTaskVo{}, errors.New("no available API KEY for Suno")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
 | 
			
		||||
	var res QueryRespVo
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
 | 
			
		||||
	var res LumaTaskVo
 | 
			
		||||
	r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
 | 
			
		||||
		return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer r.Body.Close()
 | 
			
		||||
	body, _ := io.ReadAll(r.Body)
 | 
			
		||||
	err = json.Unmarshal(body, &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
		return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -547,7 +547,7 @@ const removeChat = function (chat) {
 | 
			
		||||
            return e1.id === e2.id
 | 
			
		||||
          })
 | 
			
		||||
          // 重置会话
 | 
			
		||||
          newChat();
 | 
			
		||||
          _newChat();
 | 
			
		||||
        }).catch(e => {
 | 
			
		||||
          ElMessage.error("操作失败:" + e.message);
 | 
			
		||||
        })
 | 
			
		||||
 
 | 
			
		||||
@@ -302,6 +302,9 @@
 | 
			
		||||
                <el-form-item label="Suno 算力" prop="suno_power">
 | 
			
		||||
                  <el-input v-model.number="system['suno_power']" placeholder="使用 Suno 生成一首音乐消耗算力"/>
 | 
			
		||||
                </el-form-item>
 | 
			
		||||
                <el-form-item label="Luma 算力" prop="luma_power">
 | 
			
		||||
                  <el-input v-model.number="system['luma_power']" placeholder="使用 Luma 生成一段视频消耗算力"/>
 | 
			
		||||
                </el-form-item>
 | 
			
		||||
              </el-tab-pane>
 | 
			
		||||
            </el-tabs>
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user