mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
adjust chat records layout styles
This commit is contained in:
@@ -224,6 +224,7 @@ func needLogin(c *gin.Context) bool {
|
||||
c.Request.URL.Path == "/api/payment/wechat/notify" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/client" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
|
||||
@@ -88,8 +88,7 @@ type SunoTask struct {
|
||||
Title string `json:"title"`
|
||||
RefTaskId string `json:"ref_task_id"`
|
||||
RefSongId string `json:"ref_song_id"`
|
||||
Lyrics string `json:"lyrics"` // 歌词:自定义模式
|
||||
Prompt string `json:"prompt"` // 提示词:灵感模式
|
||||
Prompt string `json:"prompt"` // 提示词/歌词
|
||||
Tags string `json:"tags"`
|
||||
Model string `json:"model"`
|
||||
Instrumental bool `json:"instrumental"` // 是否纯音乐
|
||||
|
||||
@@ -11,48 +11,55 @@ import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/suno"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SunoHandler struct {
|
||||
BaseHandler
|
||||
service *suno.Service
|
||||
service *suno.Service
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB) *SunoHandler {
|
||||
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler {
|
||||
return &SunoHandler{
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
DB: db,
|
||||
},
|
||||
service: service,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SunoHandler) Client(c *gin.Context) {
|
||||
//ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
//if err != nil {
|
||||
// logger.Error(err)
|
||||
// c.Abort()
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//userId := h.GetInt(c, "user_id", 0)
|
||||
//if userId == 0 {
|
||||
// logger.Info("Invalid user ID")
|
||||
// c.Abort()
|
||||
// return
|
||||
//}
|
||||
//
|
||||
////client := types.NewWsClient(ws)
|
||||
//logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Create(c *gin.Context) {
|
||||
@@ -88,6 +95,9 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Power: h.App.SysConfig.SunoPower,
|
||||
}
|
||||
if data.Lyrics != "" {
|
||||
job.Prompt = data.Lyrics
|
||||
}
|
||||
tx := h.DB.Create(&job)
|
||||
if tx.Error != nil {
|
||||
resp.ERROR(c, tx.Error.Error())
|
||||
@@ -100,7 +110,6 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
UserId: job.UserId,
|
||||
Type: job.Type,
|
||||
Title: job.Title,
|
||||
Lyrics: data.Lyrics,
|
||||
RefTaskId: data.RefTaskId,
|
||||
RefSongId: data.RefSongId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
@@ -128,19 +137,74 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
var itemVo vo.SunoJob
|
||||
_ = utils.CopyObject(job, &itemVo)
|
||||
resp.SUCCESS(c, itemVo)
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
func (h *SunoHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
page := h.GetInt(c, "page", 0)
|
||||
pageSize := h.GetInt(c, "page_size", 0)
|
||||
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
session.Debug().Model(&model.SunoJob{}).Count(&total)
|
||||
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
session = session.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
var list []model.SunoJob
|
||||
err := session.Order("id desc").Find(&list).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 VO
|
||||
items := make([]vo.SunoJob, 0)
|
||||
for _, v := range list {
|
||||
var item vo.SunoJob
|
||||
err = utils.CopyObject(v, &item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Remove(c *gin.Context) {
|
||||
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
var job model.SunoJob
|
||||
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 删除任务
|
||||
h.DB.Delete(&job)
|
||||
// 删除文件
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.ThumbImgURL)
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.CoverImgURL)
|
||||
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
|
||||
}
|
||||
|
||||
func (h *SunoHandler) Publish(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
userId := h.GetLoginUserId(c)
|
||||
publish := h.GetBool(c, "publish")
|
||||
err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -214,6 +214,8 @@ func main() {
|
||||
fx.Invoke(func(s *suno.Service) {
|
||||
s.Run()
|
||||
s.SyncTaskProgress()
|
||||
s.CheckTaskNotify()
|
||||
s.DownloadImages()
|
||||
}),
|
||||
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
@@ -53,6 +54,25 @@ func (s *Service) PushTask(task types.SunoTask) {
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
// 将数据库中未提交的人物加载到队列
|
||||
var jobs []model.SunoJob
|
||||
s.db.Where("task_id", "").Find(&jobs)
|
||||
for _, v := range jobs {
|
||||
s.PushTask(types.SunoTask{
|
||||
Id: v.Id,
|
||||
Channel: v.Channel,
|
||||
UserId: v.UserId,
|
||||
Type: v.Type,
|
||||
Title: v.Title,
|
||||
RefTaskId: v.RefTaskId,
|
||||
RefSongId: v.RefSongId,
|
||||
Prompt: v.Prompt,
|
||||
Tags: v.Tags,
|
||||
Model: v.ModelName,
|
||||
Instrumental: v.Instrumental,
|
||||
ExtendSecs: v.ExtendSecs,
|
||||
})
|
||||
}
|
||||
logger.Info("Starting Suno job consumer...")
|
||||
go func() {
|
||||
for {
|
||||
@@ -83,7 +103,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
type RespVo struct {
|
||||
Code int `json:"code"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
@@ -111,7 +131,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
if task.Type == 1 {
|
||||
reqBody["gpt_description_prompt"] = task.Prompt
|
||||
} else { // 自定义模式
|
||||
reqBody["prompt"] = task.Lyrics
|
||||
reqBody["prompt"] = task.Prompt
|
||||
reqBody["tags"] = task.Tags
|
||||
reqBody["mv"] = task.Model
|
||||
reqBody["title"] = task.Title
|
||||
@@ -131,12 +151,77 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%s", string(body))
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
res.Channel = apiKey.ApiURL
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *Service) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Suno task notify checking ...")
|
||||
for {
|
||||
var message sd.NotifyMessage
|
||||
err := s.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := s.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) DownloadImages() {
|
||||
go func() {
|
||||
var items []model.SunoJob
|
||||
for {
|
||||
res := s.db.Where("progress", 102).Find(&items)
|
||||
if res.Error != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
// 下载图片和音频
|
||||
logger.Infof("try download thumb image: %s", v.ThumbImgURL)
|
||||
thumbURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ThumbImgURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download cover image: %s", v.CoverImgURL)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverImgURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download audio: %s", v.AudioURL)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
}
|
||||
v.ThumbImgURL = thumbURL
|
||||
v.CoverImgURL = coverURL
|
||||
v.AudioURL = audioURL
|
||||
v.Progress = 100
|
||||
s.db.Updates(&v)
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished})
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
func (s *Service) SyncTaskProgress() {
|
||||
go func() {
|
||||
@@ -167,7 +252,7 @@ func (s *Service) SyncTaskProgress() {
|
||||
tx := s.db.Begin()
|
||||
for _, v := range task.Data.Data {
|
||||
job.Id = 0
|
||||
job.Progress = 100
|
||||
job.Progress = 102 // 102 表示资源未下载完成
|
||||
job.Title = v.Title
|
||||
job.SongId = v.Id
|
||||
job.Duration = int(v.Metadata.Duration)
|
||||
@@ -175,26 +260,9 @@ func (s *Service) SyncTaskProgress() {
|
||||
job.Tags = v.Metadata.Tags
|
||||
job.ModelName = v.ModelName
|
||||
job.RawData = utils.JsonEncode(v)
|
||||
|
||||
// 下载图片和音频
|
||||
thumbURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.ImageLargeUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioUrl, true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
}
|
||||
job.ThumbImgURL = thumbURL
|
||||
job.CoverImgURL = coverURL
|
||||
job.AudioURL = audioURL
|
||||
job.ThumbImgURL = v.ImageUrl
|
||||
job.CoverImgURL = v.ImageLargeUrl
|
||||
job.AudioURL = v.AudioUrl
|
||||
|
||||
if err = tx.Create(&job).Error; err != nil {
|
||||
logger.Error("create job with error: %v", err)
|
||||
@@ -212,13 +280,13 @@ func (s *Service) SyncTaskProgress() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
} else if task.Data.FailReason != "" {
|
||||
job.Progress = 101
|
||||
job.ErrMsg = task.Data.FailReason
|
||||
s.db.Updates(&job)
|
||||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,7 +353,7 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error)
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return QueryRespVo{}, fmt.Errorf("解析API数据失败:%s", string(body))
|
||||
return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
||||
Reference in New Issue
Block a user