diff --git a/api/core/app_server.go b/api/core/app_server.go index 7a378dd0..9f00601d 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -224,6 +224,7 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/payment/wechat/notify" || c.Request.URL.Path == "/api/payment/doPay" || c.Request.URL.Path == "/api/payment/payWays" || + c.Request.URL.Path == "/api/suno/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") || strings.HasPrefix(c.Request.URL.Path, "/api/config/") || diff --git a/api/core/types/task.go b/api/core/types/task.go index 95e7fc9f..0fb451b7 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -88,8 +88,7 @@ type SunoTask struct { Title string `json:"title"` RefTaskId string `json:"ref_task_id"` RefSongId string `json:"ref_song_id"` - Lyrics string `json:"lyrics"` // 歌词:自定义模式 - Prompt string `json:"prompt"` // 提示词:灵感模式 + Prompt string `json:"prompt"` // 提示词/歌词 Tags string `json:"tags"` Model string `json:"model"` Instrumental bool `json:"instrumental"` // 是否纯音乐 diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 98398b9d..fb14283e 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -11,48 +11,55 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service/oss" "geekai/service/suno" "geekai/store/model" "geekai/store/vo" "geekai/utils" "geekai/utils/resp" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "gorm.io/gorm" + "net/http" "time" ) type SunoHandler struct { BaseHandler - service *suno.Service + service *suno.Service + uploader *oss.UploaderManager } -func NewSunoHandler(app *core.AppServer, db *gorm.DB) *SunoHandler { +func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { return &SunoHandler{ BaseHandler: BaseHandler{ App: app, DB: db, }, + service: service, + uploader: uploader, } } // Client WebSocket 客户端,用于通知任务状态变更 func (h *SunoHandler) Client(c *gin.Context) { - //ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - //if err != nil { - // logger.Error(err) - // c.Abort() - // return - //} - // - //userId := h.GetInt(c, "user_id", 0) - //if userId == 0 { - // logger.Info("Invalid user ID") - // c.Abort() - // return - //} - // - ////client := types.NewWsClient(ws) - //logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.service.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } func (h *SunoHandler) Create(c *gin.Context) { @@ -88,6 +95,9 @@ func (h *SunoHandler) Create(c *gin.Context) { ExtendSecs: data.ExtendSecs, Power: h.App.SysConfig.SunoPower, } + if data.Lyrics != "" { + job.Prompt = data.Lyrics + } tx := h.DB.Create(&job) if tx.Error != nil { resp.ERROR(c, tx.Error.Error()) @@ -100,7 +110,6 @@ func (h *SunoHandler) Create(c *gin.Context) { UserId: job.UserId, Type: job.Type, Title: job.Title, - Lyrics: data.Lyrics, RefTaskId: data.RefTaskId, RefSongId: data.RefSongId, ExtendSecs: data.ExtendSecs, @@ -128,19 +137,74 @@ func (h *SunoHandler) Create(c *gin.Context) { }) } - var itemVo vo.SunoJob - _ = utils.CopyObject(job, &itemVo) - resp.SUCCESS(c, itemVo) + client := h.service.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) } func (h *SunoHandler) List(c *gin.Context) { + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) + // 统计总数 + var total int64 + session.Debug().Model(&model.SunoJob{}).Count(&total) + + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + var list []model.SunoJob + err := session.Order("id desc").Find(&list).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + // 转换为 VO + items := make([]vo.SunoJob, 0) + for _, v := range list { + var item vo.SunoJob + err = utils.CopyObject(v, &item) + if err != nil { + continue + } + items = append(items, item) + } + + resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items)) } func (h *SunoHandler) Remove(c *gin.Context) { - + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + var job model.SunoJob + err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + // 删除任务 + h.DB.Delete(&job) + // 删除文件 + _ = h.uploader.GetUploadHandler().Delete(job.ThumbImgURL) + _ = h.uploader.GetUploadHandler().Delete(job.CoverImgURL) + _ = h.uploader.GetUploadHandler().Delete(job.AudioURL) } func (h *SunoHandler) Publish(c *gin.Context) { + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + publish := h.GetBool(c, "publish") + err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + resp.SUCCESS(c) } diff --git a/api/main.go b/api/main.go index 21d26be5..d5541fb8 100644 --- a/api/main.go +++ b/api/main.go @@ -214,6 +214,8 @@ func main() { fx.Invoke(func(s *suno.Service) { s.Run() s.SyncTaskProgress() + s.CheckTaskNotify() + s.DownloadImages() }), fx.Provide(payment.NewAlipayService), diff --git a/api/service/suno/service.go b/api/service/suno/service.go index df1e5876..83f0fc58 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -14,6 +14,7 @@ import ( "geekai/core/types" logger2 "geekai/logger" "geekai/service/oss" + "geekai/service/sd" "geekai/store" "geekai/store/model" "geekai/utils" @@ -53,6 +54,25 @@ func (s *Service) PushTask(task types.SunoTask) { } func (s *Service) Run() { + // 将数据库中未提交的人物加载到队列 + var jobs []model.SunoJob + s.db.Where("task_id", "").Find(&jobs) + for _, v := range jobs { + s.PushTask(types.SunoTask{ + Id: v.Id, + Channel: v.Channel, + UserId: v.UserId, + Type: v.Type, + Title: v.Title, + RefTaskId: v.RefTaskId, + RefSongId: v.RefSongId, + Prompt: v.Prompt, + Tags: v.Tags, + Model: v.ModelName, + Instrumental: v.Instrumental, + ExtendSecs: v.ExtendSecs, + }) + } logger.Info("Starting Suno job consumer...") go func() { for { @@ -83,7 +103,7 @@ func (s *Service) Run() { } type RespVo struct { - Code int `json:"code"` + Code string `json:"code"` Message string `json:"message"` Data string `json:"data"` Channel string `json:"channel,omitempty"` @@ -111,7 +131,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) { if task.Type == 1 { reqBody["gpt_description_prompt"] = task.Prompt } else { // 自定义模式 - reqBody["prompt"] = task.Lyrics + reqBody["prompt"] = task.Prompt reqBody["tags"] = task.Tags reqBody["mv"] = task.Model reqBody["title"] = task.Title @@ -131,12 +151,77 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) { body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { - return RespVo{}, fmt.Errorf("解析API数据失败:%s", string(body)) + return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) } res.Channel = apiKey.ApiURL return res, nil } +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running Suno task notify checking ...") + for { + var message sd.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadImages() { + go func() { + var items []model.SunoJob + for { + res := s.db.Where("progress", 102).Find(&items) + if res.Error != nil { + continue + } + + for _, v := range items { + // 下载图片和音频 + logger.Infof("try download thumb image: %s", v.ThumbImgURL) + thumbURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ThumbImgURL, true) + if err != nil { + logger.Errorf("download image with error: %v", err) + continue + } + + logger.Infof("try download cover image: %s", v.CoverImgURL) + coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverImgURL, true) + if err != nil { + logger.Errorf("download image with error: %v", err) + continue + } + + logger.Infof("try download audio: %s", v.AudioURL) + audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true) + if err != nil { + logger.Errorf("download audio with error: %v", err) + continue + } + v.ThumbImgURL = thumbURL + v.CoverImgURL = coverURL + v.AudioURL = audioURL + v.Progress = 100 + s.db.Updates(&v) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished}) + } + + time.Sleep(time.Second * 10) + } + }() +} + // SyncTaskProgress 异步拉取任务 func (s *Service) SyncTaskProgress() { go func() { @@ -167,7 +252,7 @@ func (s *Service) SyncTaskProgress() { tx := s.db.Begin() for _, v := range task.Data.Data { job.Id = 0 - job.Progress = 100 + job.Progress = 102 // 102 表示资源未下载完成 job.Title = v.Title job.SongId = v.Id job.Duration = int(v.Metadata.Duration) @@ -175,26 +260,9 @@ func (s *Service) SyncTaskProgress() { job.Tags = v.Metadata.Tags job.ModelName = v.ModelName job.RawData = utils.JsonEncode(v) - - // 下载图片和音频 - thumbURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageUrl, true) - if err != nil { - logger.Errorf("download image with error: %v", err) - continue - } - coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageLargeUrl, true) - if err != nil { - logger.Errorf("download image with error: %v", err) - continue - } - audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioUrl, true) - if err != nil { - logger.Errorf("download audio with error: %v", err) - continue - } - job.ThumbImgURL = thumbURL - job.CoverImgURL = coverURL - job.AudioURL = audioURL + job.ThumbImgURL = v.ImageUrl + job.CoverImgURL = v.ImageLargeUrl + job.AudioURL = v.AudioUrl if err = tx.Create(&job).Error; err != nil { logger.Error("create job with error: %v", err) @@ -212,13 +280,13 @@ func (s *Service) SyncTaskProgress() { continue } } - tx.Commit() } else if task.Data.FailReason != "" { job.Progress = 101 job.ErrMsg = task.Data.FailReason s.db.Updates(&job) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) } } @@ -285,7 +353,7 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { - return QueryRespVo{}, fmt.Errorf("解析API数据失败:%s", string(body)) + return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) } return res, nil diff --git a/web/src/assets/css/image-mj.styl b/web/src/assets/css/image-mj.styl index 2f08175f..afb7671d 100644 --- a/web/src/assets/css/image-mj.styl +++ b/web/src/assets/css/image-mj.styl @@ -427,22 +427,18 @@ .err-msg-container { overflow hidden word-break break-all - padding 0 10px 20px 10px + padding 15px .title { - font-size 16px + font-size 20px text-align center font-weight bold color #f56c6c - margin-bottom 20px + margin-bottom 30px } - .text { - font-size 14px - color #E9F1F6 - line-height 1.5 - text-overflow ellipsis - height 100px - overflow hidden + .opt { + display flex + justify-content center } } .iconfont { diff --git a/web/src/assets/css/suno.styl b/web/src/assets/css/suno.styl index 2ac02bc0..08a1bca1 100644 --- a/web/src/assets/css/suno.styl +++ b/web/src/assets/css/suno.styl @@ -96,7 +96,7 @@ .right-box { width 100% color rgb(250 247 245) - overflow hidden + overflow auto .list-box { padding 0 0 0 20px @@ -105,9 +105,10 @@ flex-flow row padding 5px 0 cursor pointer + margin-bottom 10px &:hover { - background-color #1C1616 + background-color #2A2525 } .left { @@ -246,11 +247,41 @@ } } - .item.active { + + .task { + height 100px background-color #2A2525 + display flex + margin-bottom 10px + .left { + display flex + justify-content left + align-items center + padding 20px + width 320px + .title { + font-size 14px + color #e1e1e1 + white-space: nowrap; /* 防止文字换行 */ + overflow: hidden; /* 隐藏溢出的内容 */ + text-overflow: ellipsis; /* 用省略号表示溢出的内容 */ + } + } + + .right { + display flex + width 100% + justify-content center + } } } + + .pagination { + padding 10px 20px + display flex + justify-content center + } .music-player { width 100% position: fixed; diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index 0603c037..92dd392c 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -70,7 +70,7 @@
@@ -132,12 +132,12 @@ const content =ref(processPrompt(props.data.content)) const files = ref([]) onMounted(() => { - if (!finalTokens.value) { - httpPost("/api/chat/tokens", {text: props.data.content, model: props.data.model}).then(res => { - finalTokens.value = res.data; - }).catch(() => { - }) - } + // if (!finalTokens.value) { + // httpPost("/api/chat/tokens", {text: props.data.content, model: props.data.model}).then(res => { + // finalTokens.value = res.data; + // }).catch(() => { + // }) + // } const linkRegex = /(https?:\/\/\S+)/g; const links = props.data.content.match(linkRegex); @@ -308,9 +308,10 @@ const isExternalImg = (link, files) => { display flex; width 100%; padding 0 25px; + flex-flow row-reverse .chat-icon { - margin-right 20px; + margin-left 20px; img { width: 36px; @@ -377,6 +378,7 @@ const isExternalImg = (link, files) => { .content-wrapper { display flex + flex-flow row-reverse .content { word-break break-word; padding: 1rem @@ -384,7 +386,7 @@ const isExternalImg = (link, files) => { font-size: var(--content-font-size); overflow: auto; background-color #98e165 - border-radius: 0 10px 10px 10px; + border-radius: 10px 0 10px 10px; img { max-width: 600px; diff --git a/web/src/components/ChatReply.vue b/web/src/components/ChatReply.vue index 6e64dbb0..37edc9c5 100644 --- a/web/src/components/ChatReply.vue +++ b/web/src/components/ChatReply.vue @@ -67,14 +67,13 @@ -