diff --git a/CHANGELOG.md b/CHANGELOG.md
index 867f9b27..a190b0e3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
* 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
* Bug修复:手机端角色和模型选择不生效
* Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常
+* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
* 功能新增:支持 Luma 文生视频功能
## v4.1.2
diff --git a/api/core/types/config.go b/api/core/types/config.go
index 05dec367..9638f620 100644
--- a/api/core/types/config.go
+++ b/api/core/types/config.go
@@ -150,8 +150,9 @@ type SystemConfig struct {
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
- DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
+ DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
+ LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go
index 9151a0eb..721ac4e0 100644
--- a/api/handler/suno_handler.go
+++ b/api/handler/suno_handler.go
@@ -166,8 +166,8 @@ func (h *SunoHandler) Create(c *gin.Context) {
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
- page := h.GetInt(c, "page", 0)
- pageSize := h.GetInt(c, "page_size", 0)
+ page := h.GetInt(c, "page", 1)
+ pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数
diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go
new file mode 100644
index 00000000..d7e7b5bb
--- /dev/null
+++ b/api/handler/video_handler.go
@@ -0,0 +1,240 @@
+package handler
+
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// * Copyright 2023 The Geek-AI Authors. All rights reserved.
+// * Use of this source code is governed by a Apache-2.0 license
+// * that can be found in the LICENSE file.
+// * @Author yangjian102621@163.com
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import (
+ "fmt"
+ "geekai/core"
+ "geekai/core/types"
+ "geekai/service/oss"
+ "geekai/service/video"
+ "geekai/store/model"
+ "geekai/store/vo"
+ "geekai/utils"
+ "geekai/utils/resp"
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+ "gorm.io/gorm"
+ "net/http"
+ "time"
+)
+
+type VideoHandler struct {
+ BaseHandler
+ service *video.Service
+ uploader *oss.UploaderManager
+}
+
+func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager) *VideoHandler {
+ return &VideoHandler{
+ BaseHandler: BaseHandler{
+ App: app,
+ DB: db,
+ },
+ service: service,
+ uploader: uploader,
+ }
+}
+
+// Client WebSocket 客户端,用于通知任务状态变更
+func (h *VideoHandler) Client(c *gin.Context) {
+ ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
+ if err != nil {
+ logger.Error(err)
+ c.Abort()
+ return
+ }
+
+ userId := h.GetInt(c, "user_id", 0)
+ if userId == 0 {
+ logger.Info("Invalid user ID")
+ c.Abort()
+ return
+ }
+
+ client := types.NewWsClient(ws)
+ h.service.Clients.Put(uint(userId), client)
+ logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
+}
+
+func (h *VideoHandler) LumaCreate(c *gin.Context) {
+
+ var data struct {
+ Prompt string `json:"prompt"`
+ FirstFrameImg string `json:"first_frame_img,omitempty"`
+ EndFrameImg string `json:"end_frame_img,omitempty"`
+ ExpandPrompt bool `json:"expand_prompt,omitempty"`
+ Loop bool `json:"loop,omitempty"`
+ }
+ if err := c.ShouldBindJSON(&data); err != nil {
+ resp.ERROR(c, types.InvalidArgs)
+ return
+ }
+ if data.Prompt == "" {
+ resp.ERROR(c, "prompt is needed")
+ return
+ }
+
+ userId := int(h.GetLoginUserId(c))
+ params := types.VideoParams{
+ PromptOptimize: data.ExpandPrompt,
+ Loop: data.Loop,
+ StartImgURL: data.FirstFrameImg,
+ EndImgURL: data.EndFrameImg,
+ }
+ // 插入数据库
+ job := model.VideoJob{
+ UserId: userId,
+ Type: types.VideoLuma,
+ Prompt: data.Prompt,
+ Power: h.App.SysConfig.LumaPower,
+ Params: utils.JsonEncode(params),
+ }
+ tx := h.DB.Create(&job)
+ if tx.Error != nil {
+ resp.ERROR(c, tx.Error.Error())
+ return
+ }
+
+ // 创建任务
+ h.service.PushTask(types.VideoTask{
+ Id: job.Id,
+ UserId: userId,
+ Type: types.VideoLuma,
+ Prompt: data.Prompt,
+ Params: params,
+ })
+
+ // update user's power
+ tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
+ // 记录算力变化日志
+ if tx.Error == nil && tx.RowsAffected > 0 {
+ user, _ := h.GetLoginUser(c)
+ h.DB.Create(&model.PowerLog{
+ UserId: user.Id,
+ Username: user.Username,
+ Type: types.PowerConsume,
+ Amount: job.Power,
+ Balance: user.Power - job.Power,
+ Mark: types.PowerSub,
+ Model: "luma",
+ Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id),
+ CreatedAt: time.Now(),
+ })
+ }
+
+ client := h.service.Clients.Get(uint(job.UserId))
+ if client != nil {
+ _ = client.Send([]byte("Task Updated"))
+ }
+ resp.SUCCESS(c)
+}
+
+func (h *VideoHandler) List(c *gin.Context) {
+ userId := h.GetLoginUserId(c)
+ t := c.Query("type")
+ page := h.GetInt(c, "page", 1)
+ pageSize := h.GetInt(c, "page_size", 20)
+ session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
+ if t != "" {
+ session = session.Where("type", t)
+ }
+
+ // 统计总数
+ var total int64
+ session.Model(&model.VideoJob{}).Count(&total)
+
+ if page > 0 && pageSize > 0 {
+ offset := (page - 1) * pageSize
+ session = session.Offset(offset).Limit(pageSize)
+ }
+ var list []model.VideoJob
+ err := session.Order("id desc").Find(&list).Error
+ if err != nil {
+ resp.ERROR(c, err.Error())
+ return
+ }
+
+ // 转换为 VO
+ items := make([]vo.VideoJob, 0)
+ for _, v := range list {
+ var item vo.VideoJob
+ err = utils.CopyObject(v, &item)
+ if err != nil {
+ continue
+ }
+ item.CreatedAt = v.CreatedAt.Unix()
+ items = append(items, item)
+ }
+
+ resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
+}
+
+func (h *VideoHandler) Remove(c *gin.Context) {
+ id := h.GetInt(c, "id", 0)
+ userId := h.GetLoginUserId(c)
+ var job model.VideoJob
+ err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
+ if err != nil {
+ resp.ERROR(c, err.Error())
+ return
+ }
+ // 删除任务
+ tx := h.DB.Begin()
+ if err := tx.Delete(&job).Error; err != nil {
+ tx.Rollback()
+ resp.ERROR(c, err.Error())
+ return
+ }
+
+ // 如果任务未完成,或者任务失败,则恢复用户算力
+ if job.Progress != 100 {
+ err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
+ if err != nil {
+ tx.Rollback()
+ resp.ERROR(c, err.Error())
+ return
+ }
+ var user model.User
+ tx.Where("id = ?", job.UserId).First(&user)
+ err = tx.Create(&model.PowerLog{
+ UserId: user.Id,
+ Username: user.Username,
+ Type: types.PowerRefund,
+ Amount: job.Power,
+ Balance: user.Power,
+ Mark: types.PowerAdd,
+ Model: "luma",
+ Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg),
+ CreatedAt: time.Now(),
+ }).Error
+ if err != nil {
+ tx.Rollback()
+ resp.ERROR(c, err.Error())
+ return
+ }
+ }
+ tx.Commit()
+
+ // 删除文件
+ _ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
+ _ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
+}
+
+func (h *VideoHandler) Publish(c *gin.Context) {
+ id := h.GetInt(c, "id", 0)
+ userId := h.GetLoginUserId(c)
+ publish := h.GetBool(c, "publish")
+ err := h.DB.Model(&model.VideoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
+ if err != nil {
+ resp.ERROR(c, err.Error())
+ return
+ }
+
+ resp.SUCCESS(c)
+}
diff --git a/api/main.go b/api/main.go
index 31dfc0f6..e3d0ae7a 100644
--- a/api/main.go
+++ b/api/main.go
@@ -24,6 +24,7 @@ import (
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/suno"
+ "geekai/service/video"
"geekai/store"
"io"
"log"
@@ -201,6 +202,13 @@ func main() {
s.CheckTaskNotify()
s.DownloadFiles()
}),
+ fx.Provide(video.NewService),
+ fx.Invoke(func(s *video.Service) {
+ s.Run()
+ s.SyncTaskProgress()
+ s.CheckTaskNotify()
+ s.DownloadFiles()
+ }),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
@@ -484,6 +492,15 @@ func main() {
group.GET("play", h.Play)
group.POST("lyric", h.Lyric)
}),
+ fx.Provide(handler.NewVideoHandler),
+ fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
+ group := s.Engine.Group("/api/video")
+ group.Any("client", h.Client)
+ group.POST("luma/create", h.LumaCreate)
+ group.GET("list", h.List)
+ group.GET("remove", h.Remove)
+ group.GET("publish", h.Publish)
+ }),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")
diff --git a/api/service/suno/service.go b/api/service/suno/service.go
index 53b213ef..26c545c9 100644
--- a/api/service/suno/service.go
+++ b/api/service/suno/service.go
@@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
+ if r.StatusCode != 200 {
+ return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
+ }
+
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
diff --git a/api/service/video/luma.go b/api/service/video/luma.go
index c94d2565..3bea43a9 100644
--- a/api/service/video/luma.go
+++ b/api/service/video/luma.go
@@ -82,8 +82,8 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
- var r RespVo
- r, err = s.CreateLuma(task)
+ var r LumaRespVo
+ r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
@@ -96,14 +96,15 @@ func (s *Service) Run() {
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
- "task_id": r.Id,
- "channel": r.Channel,
+ "task_id": r.Id,
+ "channel": r.Channel,
+ "prompt_ext": r.Prompt,
})
}
}()
}
-type RespVo struct {
+type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
@@ -114,7 +115,7 @@ type RespVo struct {
Channel string `json:"channel,omitempty"`
}
-func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
+func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
@@ -123,7 +124,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
- return RespVo{}, errors.New("no available API KEY for Suno")
+ return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
@@ -133,7 +134,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
- var res RespVo
+ var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
@@ -141,13 +142,17 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
SetBody(reqBody).
Post(apiURL)
if err != nil {
- return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
+ return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
+ }
+
+ if r.StatusCode != 200 {
+ return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
- return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
+ return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
}
// update the last_use_at for api key
@@ -180,7 +185,7 @@ func (s *Service) CheckTaskNotify() {
func (s *Service) DownloadFiles() {
go func() {
- var items []model.SunoJob
+ var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
@@ -188,22 +193,13 @@ func (s *Service) DownloadFiles() {
}
for _, v := range items {
- // 下载图片和音频
- logger.Infof("try download cover image: %s", v.CoverURL)
- coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
- if err != nil {
- logger.Errorf("download image with error: %v", err)
- continue
- }
-
- logger.Infof("try download audio: %s", v.AudioURL)
- audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
+ logger.Infof("try download video: %s", v.VideoURL)
+ videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
- v.CoverURL = coverURL
- v.AudioURL = audioURL
+ v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
@@ -217,7 +213,7 @@ func (s *Service) DownloadFiles() {
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
- var jobs []model.SunoJob
+ var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
@@ -225,60 +221,14 @@ func (s *Service) SyncTaskProgress() {
}
for _, job := range jobs {
- task, err := s.QueryTask(job.TaskId, job.Channel)
+ task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
- if task.Code != "success" {
- logger.Errorf("query task with error: %v", task.Message)
- continue
- }
+ logger.Debugf("task: %+v", task)
- logger.Debugf("task: %+v", task.Data.Status)
- // 任务完成,删除旧任务插入两条新任务
- if task.Data.Status == "SUCCESS" {
- var jobId = job.Id
- var flag = false
- tx := s.db.Begin()
- for _, v := range task.Data.Data {
- job.Id = 0
- job.Progress = 102 // 102 表示资源未下载完成
- job.Title = v.Title
- job.SongId = v.Id
- job.Duration = int(v.Metadata.Duration)
- job.Prompt = v.Metadata.Prompt
- job.Tags = v.Metadata.Tags
- job.ModelName = v.ModelName
- job.RawData = utils.JsonEncode(v)
- job.CoverURL = v.ImageLargeUrl
- job.AudioURL = v.AudioUrl
-
- if err = tx.Create(&job).Error; err != nil {
- logger.Error("create job with error: %v", err)
- tx.Rollback()
- break
- }
- flag = true
- }
-
- // 删除旧任务
- if flag {
- if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
- logger.Error("create job with error: %v", err)
- tx.Rollback()
- continue
- }
- }
- tx.Commit()
-
- } else if task.Data.FailReason != "" {
- job.Progress = service.FailTaskProgress
- job.ErrMsg = task.Data.FailReason
- s.db.Updates(&job)
- s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
- }
}
time.Sleep(time.Second * 10)
@@ -286,42 +236,22 @@ func (s *Service) SyncTaskProgress() {
}()
}
-type QueryRespVo struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Data struct {
- TaskId string `json:"task_id"`
- Action string `json:"action"`
- Status string `json:"status"`
- FailReason string `json:"fail_reason"`
- SubmitTime int `json:"submit_time"`
- StartTime int `json:"start_time"`
- FinishTime int `json:"finish_time"`
- Progress string `json:"progress"`
- Data []struct {
- Id string `json:"id"`
- Title string `json:"title"`
- Status string `json:"status"`
- Metadata struct {
- Tags string `json:"tags"`
- Type string `json:"type"`
- Prompt string `json:"prompt"`
- Stream bool `json:"stream"`
- Duration float64 `json:"duration"`
- ErrorMessage interface{} `json:"error_message"`
- } `json:"metadata"`
- AudioUrl string `json:"audio_url"`
- ImageUrl string `json:"image_url"`
- VideoUrl string `json:"video_url"`
- ModelName string `json:"model_name"`
- DisplayName string `json:"display_name"`
- ImageLargeUrl string `json:"image_large_url"`
- MajorModelVersion string `json:"major_model_version"`
- } `json:"data"`
- } `json:"data"`
+type LumaTaskVo struct {
+ Id string `json:"id"`
+ Liked interface{} `json:"liked"`
+ State string `json:"state"`
+ Video struct {
+ Url string `json:"url"`
+ Width int `json:"width"`
+ Height int `json:"height"`
+ DownloadUrl string `json:"download_url"`
+ } `json:"video"`
+ Prompt string `json:"prompt"`
+ CreatedAt time.Time `json:"created_at"`
+ EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
-func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
+func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
@@ -329,22 +259,22 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error)
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
- return QueryRespVo{}, errors.New("no available API KEY for Suno")
+ return LumaTaskVo{}, errors.New("no available API KEY for Suno")
}
- apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
- var res QueryRespVo
+ apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
+ var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
- return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
+ return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
- return QueryRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
+ return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
}
return res, nil
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index c77a9972..b500026c 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -547,7 +547,7 @@ const removeChat = function (chat) {
return e1.id === e2.id
})
// 重置会话
- newChat();
+ _newChat();
}).catch(e => {
ElMessage.error("操作失败:" + e.message);
})
diff --git a/web/src/views/admin/SysConfig.vue b/web/src/views/admin/SysConfig.vue
index f3ece4bc..a1f21979 100644
--- a/web/src/views/admin/SysConfig.vue
+++ b/web/src/views/admin/SysConfig.vue
@@ -302,6 +302,9 @@
+
+
+