diff --git a/CHANGELOG.md b/CHANGELOG.md index 867f9b27..a190b0e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作 * Bug修复:手机端角色和模型选择不生效 * Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常 +* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验 * 功能新增:支持 Luma 文生视频功能 ## v4.1.2 diff --git a/api/core/types/config.go b/api/core/types/config.go index 05dec367..9638f620 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -150,8 +150,9 @@ type SystemConfig struct { MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 - DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 + DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力 SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力 + LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力 WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 9151a0eb..721ac4e0 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -166,8 +166,8 @@ func (h *SunoHandler) Create(c *gin.Context) { func (h *SunoHandler) List(c *gin.Context) { userId := h.GetLoginUserId(c) - page := h.GetInt(c, "page", 0) - pageSize := h.GetInt(c, "page_size", 0) + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) // 统计总数 diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go new file mode 100644 index 00000000..d7e7b5bb --- /dev/null +++ b/api/handler/video_handler.go @@ -0,0 +1,240 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service/oss" + "geekai/service/video" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "net/http" + "time" +) + +type VideoHandler struct { + BaseHandler + service *video.Service + uploader *oss.UploaderManager +} + +func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager) *VideoHandler { + return &VideoHandler{ + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + service: service, + uploader: uploader, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *VideoHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.service.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + +func (h *VideoHandler) LumaCreate(c *gin.Context) { + + var data struct { + Prompt string `json:"prompt"` + FirstFrameImg string `json:"first_frame_img,omitempty"` + EndFrameImg string `json:"end_frame_img,omitempty"` + ExpandPrompt bool `json:"expand_prompt,omitempty"` + Loop bool `json:"loop,omitempty"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + if data.Prompt == "" { + resp.ERROR(c, "prompt is needed") + return + } + + userId := int(h.GetLoginUserId(c)) + params := types.VideoParams{ + PromptOptimize: data.ExpandPrompt, + Loop: data.Loop, + StartImgURL: data.FirstFrameImg, + EndImgURL: data.EndFrameImg, + } + // 插入数据库 + job := model.VideoJob{ + UserId: userId, + Type: types.VideoLuma, + Prompt: data.Prompt, + Power: h.App.SysConfig.LumaPower, + Params: utils.JsonEncode(params), + } + tx := h.DB.Create(&job) + if tx.Error != nil { + resp.ERROR(c, tx.Error.Error()) + return + } + + // 创建任务 + h.service.PushTask(types.VideoTask{ + Id: job.Id, + UserId: userId, + Type: types.VideoLuma, + Prompt: data.Prompt, + Params: params, + }) + + // update user's power + tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + user, _ := h.GetLoginUser(c) + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power - job.Power, + Mark: types.PowerSub, + Model: "luma", + Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id), + CreatedAt: time.Now(), + }) + } + + client := h.service.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) +} + +func (h *VideoHandler) List(c *gin.Context) { + userId := h.GetLoginUserId(c) + t := c.Query("type") + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) + if t != "" { + session = session.Where("type", t) + } + + // 统计总数 + var total int64 + session.Model(&model.VideoJob{}).Count(&total) + + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + var list []model.VideoJob + err := session.Order("id desc").Find(&list).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + // 转换为 VO + items := make([]vo.VideoJob, 0) + for _, v := range list { + var item vo.VideoJob + err = utils.CopyObject(v, &item) + if err != nil { + continue + } + item.CreatedAt = v.CreatedAt.Unix() + items = append(items, item) + } + + resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items)) +} + +func (h *VideoHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + var job model.VideoJob + err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + // 删除任务 + tx := h.DB.Begin() + if err := tx.Delete(&job).Error; err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + + // 如果任务未完成,或者任务失败,则恢复用户算力 + if job.Progress != 100 { + err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + var user model.User + tx.Where("id = ?", job.UserId).First(&user) + err = tx.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerRefund, + Amount: job.Power, + Balance: user.Power, + Mark: types.PowerAdd, + Model: "luma", + Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + CreatedAt: time.Now(), + }).Error + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + } + tx.Commit() + + // 删除文件 + _ = h.uploader.GetUploadHandler().Delete(job.CoverURL) + _ = h.uploader.GetUploadHandler().Delete(job.VideoURL) +} + +func (h *VideoHandler) Publish(c *gin.Context) { + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + publish := h.GetBool(c, "publish") + err := h.DB.Model(&model.VideoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c) +} diff --git a/api/main.go b/api/main.go index 31dfc0f6..e3d0ae7a 100644 --- a/api/main.go +++ b/api/main.go @@ -24,6 +24,7 @@ import ( "geekai/service/sd" "geekai/service/sms" "geekai/service/suno" + "geekai/service/video" "geekai/store" "io" "log" @@ -201,6 +202,13 @@ func main() { s.CheckTaskNotify() s.DownloadFiles() }), + fx.Provide(video.NewService), + fx.Invoke(func(s *video.Service) { + s.Run() + s.SyncTaskProgress() + s.CheckTaskNotify() + s.DownloadFiles() + }), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), @@ -484,6 +492,15 @@ func main() { group.GET("play", h.Play) group.POST("lyric", h.Lyric) }), + fx.Provide(handler.NewVideoHandler), + fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) { + group := s.Engine.Group("/api/video") + group.Any("client", h.Client) + group.POST("luma/create", h.LumaCreate) + group.GET("list", h.List) + group.GET("remove", h.Remove) + group.GET("publish", h.Publish) + }), fx.Provide(handler.NewTestHandler), fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { group := s.Engine.Group("/api/test") diff --git a/api/service/suno/service.go b/api/service/suno/service.go index 53b213ef..26c545c9 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) { return RespVo{}, fmt.Errorf("请求 API 出错:%v", err) } + if r.StatusCode != 200 { + return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) + } + body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { diff --git a/api/service/video/luma.go b/api/service/video/luma.go index c94d2565..3bea43a9 100644 --- a/api/service/video/luma.go +++ b/api/service/video/luma.go @@ -82,8 +82,8 @@ func (s *Service) Run() { logger.Errorf("taking task with error: %v", err) continue } - var r RespVo - r, err = s.CreateLuma(task) + var r LumaRespVo + r, err = s.LumaCreate(task) if err != nil { logger.Errorf("create task with error: %v", err) s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ @@ -96,14 +96,15 @@ func (s *Service) Run() { // 更新任务信息 s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ - "task_id": r.Id, - "channel": r.Channel, + "task_id": r.Id, + "channel": r.Channel, + "prompt_ext": r.Prompt, }) } }() } -type RespVo struct { +type LumaRespVo struct { Id string `json:"id"` Prompt string `json:"prompt"` State string `json:"state"` @@ -114,7 +115,7 @@ type RespVo struct { Channel string `json:"channel,omitempty"` } -func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) { +func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) { // 读取 API KEY var apiKey model.ApiKey session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true) @@ -123,7 +124,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) { } tx := session.Order("last_used_at DESC").First(&apiKey) if tx.Error != nil { - return RespVo{}, errors.New("no available API KEY for Suno") + return LumaRespVo{}, errors.New("no available API KEY for Luma") } reqBody := map[string]interface{}{ @@ -133,7 +134,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) { "image_url": task.Params.StartImgURL, "image_end_url": task.Params.EndImgURL, } - var res RespVo + var res LumaRespVo apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL) logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) r, err := req.C().R(). @@ -141,13 +142,17 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) { SetBody(reqBody). Post(apiURL) if err != nil { - return RespVo{}, fmt.Errorf("请求 API 出错:%v", err) + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.StatusCode != 200 { + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) } body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { - return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) } // update the last_use_at for api key @@ -180,7 +185,7 @@ func (s *Service) CheckTaskNotify() { func (s *Service) DownloadFiles() { go func() { - var items []model.SunoJob + var items []model.VideoJob for { res := s.db.Where("progress", 102).Find(&items) if res.Error != nil { @@ -188,22 +193,13 @@ func (s *Service) DownloadFiles() { } for _, v := range items { - // 下载图片和音频 - logger.Infof("try download cover image: %s", v.CoverURL) - coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true) - if err != nil { - logger.Errorf("download image with error: %v", err) - continue - } - - logger.Infof("try download audio: %s", v.AudioURL) - audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true) + logger.Infof("try download video: %s", v.VideoURL) + videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true) if err != nil { logger.Errorf("download audio with error: %v", err) continue } - v.CoverURL = coverURL - v.AudioURL = audioURL + v.VideoURL = videoURL v.Progress = 100 s.db.Updates(&v) s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) @@ -217,7 +213,7 @@ func (s *Service) DownloadFiles() { // SyncTaskProgress 异步拉取任务 func (s *Service) SyncTaskProgress() { go func() { - var jobs []model.SunoJob + var jobs []model.VideoJob for { res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) if res.Error != nil { @@ -225,60 +221,14 @@ func (s *Service) SyncTaskProgress() { } for _, job := range jobs { - task, err := s.QueryTask(job.TaskId, job.Channel) + task, err := s.QueryLumaTask(job.TaskId, job.Channel) if err != nil { logger.Errorf("query task with error: %v", err) continue } - if task.Code != "success" { - logger.Errorf("query task with error: %v", task.Message) - continue - } + logger.Debugf("task: %+v", task) - logger.Debugf("task: %+v", task.Data.Status) - // 任务完成,删除旧任务插入两条新任务 - if task.Data.Status == "SUCCESS" { - var jobId = job.Id - var flag = false - tx := s.db.Begin() - for _, v := range task.Data.Data { - job.Id = 0 - job.Progress = 102 // 102 表示资源未下载完成 - job.Title = v.Title - job.SongId = v.Id - job.Duration = int(v.Metadata.Duration) - job.Prompt = v.Metadata.Prompt - job.Tags = v.Metadata.Tags - job.ModelName = v.ModelName - job.RawData = utils.JsonEncode(v) - job.CoverURL = v.ImageLargeUrl - job.AudioURL = v.AudioUrl - - if err = tx.Create(&job).Error; err != nil { - logger.Error("create job with error: %v", err) - tx.Rollback() - break - } - flag = true - } - - // 删除旧任务 - if flag { - if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil { - logger.Error("create job with error: %v", err) - tx.Rollback() - continue - } - } - tx.Commit() - - } else if task.Data.FailReason != "" { - job.Progress = service.FailTaskProgress - job.ErrMsg = task.Data.FailReason - s.db.Updates(&job) - s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) - } } time.Sleep(time.Second * 10) @@ -286,42 +236,22 @@ func (s *Service) SyncTaskProgress() { }() } -type QueryRespVo struct { - Code string `json:"code"` - Message string `json:"message"` - Data struct { - TaskId string `json:"task_id"` - Action string `json:"action"` - Status string `json:"status"` - FailReason string `json:"fail_reason"` - SubmitTime int `json:"submit_time"` - StartTime int `json:"start_time"` - FinishTime int `json:"finish_time"` - Progress string `json:"progress"` - Data []struct { - Id string `json:"id"` - Title string `json:"title"` - Status string `json:"status"` - Metadata struct { - Tags string `json:"tags"` - Type string `json:"type"` - Prompt string `json:"prompt"` - Stream bool `json:"stream"` - Duration float64 `json:"duration"` - ErrorMessage interface{} `json:"error_message"` - } `json:"metadata"` - AudioUrl string `json:"audio_url"` - ImageUrl string `json:"image_url"` - VideoUrl string `json:"video_url"` - ModelName string `json:"model_name"` - DisplayName string `json:"display_name"` - ImageLargeUrl string `json:"image_large_url"` - MajorModelVersion string `json:"major_model_version"` - } `json:"data"` - } `json:"data"` +type LumaTaskVo struct { + Id string `json:"id"` + Liked interface{} `json:"liked"` + State string `json:"state"` + Video struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + DownloadUrl string `json:"download_url"` + } `json:"video"` + Prompt string `json:"prompt"` + CreatedAt time.Time `json:"created_at"` + EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` } -func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) { +func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) { // 读取 API KEY var apiKey model.ApiKey tx := s.db.Session(&gorm.Session{}).Where("type", "suno"). @@ -329,22 +259,22 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) Where("enabled", true). Order("last_used_at DESC").First(&apiKey) if tx.Error != nil { - return QueryRespVo{}, errors.New("no available API KEY for Suno") + return LumaTaskVo{}, errors.New("no available API KEY for Suno") } - apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId) - var res QueryRespVo + apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId) + var res LumaTaskVo r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) if err != nil { - return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err) + return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err) } defer r.Body.Close() body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { - return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) } return res, nil diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index c77a9972..b500026c 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -547,7 +547,7 @@ const removeChat = function (chat) { return e1.id === e2.id }) // 重置会话 - newChat(); + _newChat(); }).catch(e => { ElMessage.error("操作失败:" + e.message); }) diff --git a/web/src/views/admin/SysConfig.vue b/web/src/views/admin/SysConfig.vue index f3ece4bc..a1f21979 100644 --- a/web/src/views/admin/SysConfig.vue +++ b/web/src/views/admin/SysConfig.vue @@ -302,6 +302,9 @@ + + +