package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Use of this source code is governed by a Apache-2.0 license // * that can be found in the LICENSE file. // * @Author yangjian102621@163.com // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( "fmt" "geekai/core" "geekai/core/types" "geekai/service/oss" "geekai/service/suno" "geekai/store/model" "geekai/store/vo" "geekai/utils" "geekai/utils/resp" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "gorm.io/gorm" "net/http" "time" ) type SunoHandler struct { BaseHandler service *suno.Service uploader *oss.UploaderManager } func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { return &SunoHandler{ BaseHandler: BaseHandler{ App: app, DB: db, }, service: service, uploader: uploader, } } // Client WebSocket 客户端,用于通知任务状态变更 func (h *SunoHandler) Client(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Error(err) c.Abort() return } userId := h.GetInt(c, "user_id", 0) if userId == 0 { logger.Info("Invalid user ID") c.Abort() return } client := types.NewWsClient(ws) h.service.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } func (h *SunoHandler) Create(c *gin.Context) { var data struct { Prompt string `json:"prompt"` Instrumental bool `json:"instrumental"` Lyrics string `json:"lyrics"` Model string `json:"model"` Tags string `json:"tags"` Title string `json:"title"` Type int `json:"type"` RefTaskId string `json:"ref_task_id"` // 续写的任务id ExtendSecs int `json:"extend_secs"` // 续写秒数 RefSongId string `json:"ref_song_id"` // 续写的歌曲id } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) return } // 插入数据库 job := model.SunoJob{ UserId: int(h.GetLoginUserId(c)), Prompt: data.Prompt, Instrumental: data.Instrumental, ModelName: data.Model, Tags: data.Tags, Title: data.Title, Type: data.Type, RefSongId: data.RefSongId, RefTaskId: data.RefTaskId, ExtendSecs: data.ExtendSecs, Power: h.App.SysConfig.SunoPower, } if data.Lyrics != "" { job.Prompt = data.Lyrics } tx := h.DB.Create(&job) if tx.Error != nil { resp.ERROR(c, tx.Error.Error()) return } // 创建任务 h.service.PushTask(types.SunoTask{ Id: job.Id, UserId: job.UserId, Type: job.Type, Title: job.Title, RefTaskId: data.RefTaskId, RefSongId: data.RefSongId, ExtendSecs: data.ExtendSecs, Prompt: data.Prompt, Tags: data.Tags, Model: data.Model, Instrumental: data.Instrumental, }) // update user's power tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) // 记录算力变化日志 if tx.Error == nil && tx.RowsAffected > 0 { user, _ := h.GetLoginUser(c) h.DB.Create(&model.PowerLog{ UserId: user.Id, Username: user.Username, Type: types.PowerConsume, Amount: job.Power, Balance: user.Power - job.Power, Mark: types.PowerSub, Model: job.ModelName, Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), CreatedAt: time.Now(), }) } client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } resp.SUCCESS(c) } func (h *SunoHandler) List(c *gin.Context) { userId := h.GetLoginUserId(c) page := h.GetInt(c, "page", 0) pageSize := h.GetInt(c, "page_size", 0) session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) // 统计总数 var total int64 session.Debug().Model(&model.SunoJob{}).Count(&total) if page > 0 && pageSize > 0 { offset := (page - 1) * pageSize session = session.Offset(offset).Limit(pageSize) } var list []model.SunoJob err := session.Order("id desc").Find(&list).Error if err != nil { resp.ERROR(c, err.Error()) return } // 转换为 VO items := make([]vo.SunoJob, 0) for _, v := range list { var item vo.SunoJob err = utils.CopyObject(v, &item) if err != nil { continue } items = append(items, item) } resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items)) } func (h *SunoHandler) Remove(c *gin.Context) { id := h.GetInt(c, "id", 0) userId := h.GetLoginUserId(c) var job model.SunoJob err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error if err != nil { resp.ERROR(c, err.Error()) return } // 删除任务 h.DB.Delete(&job) // 删除文件 _ = h.uploader.GetUploadHandler().Delete(job.ThumbImgURL) _ = h.uploader.GetUploadHandler().Delete(job.CoverImgURL) _ = h.uploader.GetUploadHandler().Delete(job.AudioURL) } func (h *SunoHandler) Publish(c *gin.Context) { id := h.GetInt(c, "id", 0) userId := h.GetLoginUserId(c) publish := h.GetBool(c, "publish") err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error if err != nil { resp.ERROR(c, err.Error()) return } resp.SUCCESS(c) }