luma create and list api is ready

This commit is contained in:
RockYang 2024-09-02 18:08:50 +08:00
parent d70fcbb9f3
commit aa46fdb113
9 changed files with 311 additions and 115 deletions

View File

@ -8,6 +8,7 @@
* 功能优化Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作
* Bug修复手机端角色和模型选择不生效
* Bug修复用户登录过期之后聊天页面出现大量报错需要刷新页面才能正常
* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验
* 功能新增:支持 Luma 文生视频功能
## v4.1.2

View File

@ -150,8 +150,9 @@ type SystemConfig struct {
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址

View File

@ -166,8 +166,8 @@ func (h *SunoHandler) Create(c *gin.Context) {
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数

View File

@ -0,0 +1,240 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type VideoHandler struct {
BaseHandler
service *video.Service
uploader *oss.UploaderManager
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
service: service,
uploader: uploader,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *VideoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
ExpandPrompt bool `json:"expand_prompt,omitempty"`
Loop bool `json:"loop,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
EndImgURL: data.EndFrameImg,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Params: utils.JsonEncode(params),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.service.PushTask(types.VideoTask{
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
})
// update user's power
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
if t != "" {
session = session.Where("type", t)
}
// 统计总数
var total int64
session.Model(&model.VideoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.VideoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 转换为 VO
items := make([]vo.VideoJob, 0)
for _, v := range list {
var item vo.VideoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *VideoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
tx.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRefund,
Amount: job.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
}
func (h *VideoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
err := h.DB.Model(&model.VideoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@ -24,6 +24,7 @@ import (
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/suno"
"geekai/service/video"
"geekai/store"
"io"
"log"
@ -201,6 +202,13 @@ func main() {
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
@ -484,6 +492,15 @@ func main() {
group.GET("play", h.Play)
group.POST("lyric", h.Lyric)
}),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.Any("client", h.Client)
group.POST("luma/create", h.LumaCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")

View File

@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {

View File

@ -82,8 +82,8 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
var r RespVo
r, err = s.CreateLuma(task)
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
@ -96,14 +96,15 @@ func (s *Service) Run() {
// 更新任务信息
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
})
}
}()
}
type RespVo struct {
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
@ -114,7 +115,7 @@ type RespVo struct {
Channel string `json:"channel,omitempty"`
}
func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
@ -123,7 +124,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
@ -133,7 +134,7 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res RespVo
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
@ -141,13 +142,17 @@ func (s *Service) CreateLuma(task types.VideoTask) (RespVo, error) {
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
@ -180,7 +185,7 @@ func (s *Service) CheckTaskNotify() {
func (s *Service) DownloadFiles() {
go func() {
var items []model.SunoJob
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
@ -188,22 +193,13 @@ func (s *Service) DownloadFiles() {
}
for _, v := range items {
// 下载图片和音频
logger.Infof("try download cover image: %s", v.CoverURL)
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
if err != nil {
logger.Errorf("download image with error: %v", err)
continue
}
logger.Infof("try download audio: %s", v.AudioURL)
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
logger.Infof("try download video: %s", v.VideoURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download audio with error: %v", err)
continue
}
v.CoverURL = coverURL
v.AudioURL = audioURL
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
@ -217,7 +213,7 @@ func (s *Service) DownloadFiles() {
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.SunoJob
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
@ -225,60 +221,14 @@ func (s *Service) SyncTaskProgress() {
}
for _, job := range jobs {
task, err := s.QueryTask(job.TaskId, job.Channel)
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
continue
}
if task.Code != "success" {
logger.Errorf("query task with error: %v", task.Message)
continue
}
logger.Debugf("task: %+v", task)
logger.Debugf("task: %+v", task.Data.Status)
// 任务完成,删除旧任务插入两条新任务
if task.Data.Status == "SUCCESS" {
var jobId = job.Id
var flag = false
tx := s.db.Begin()
for _, v := range task.Data.Data {
job.Id = 0
job.Progress = 102 // 102 表示资源未下载完成
job.Title = v.Title
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl
job.AudioURL = v.AudioUrl
if err = tx.Create(&job).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
break
}
flag = true
}
// 删除旧任务
if flag {
if err = tx.Delete(&model.SunoJob{}, "id = ?", jobId).Error; err != nil {
logger.Error("create job with error: %v", err)
tx.Rollback()
continue
}
}
tx.Commit()
} else if task.Data.FailReason != "" {
job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason
s.db.Updates(&job)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
}
}
time.Sleep(time.Second * 10)
@ -286,42 +236,22 @@ func (s *Service) SyncTaskProgress() {
}()
}
type QueryRespVo struct {
Code string `json:"code"`
Message string `json:"message"`
Data struct {
TaskId string `json:"task_id"`
Action string `json:"action"`
Status string `json:"status"`
FailReason string `json:"fail_reason"`
SubmitTime int `json:"submit_time"`
StartTime int `json:"start_time"`
FinishTime int `json:"finish_time"`
Progress string `json:"progress"`
Data []struct {
Id string `json:"id"`
Title string `json:"title"`
Status string `json:"status"`
Metadata struct {
Tags string `json:"tags"`
Type string `json:"type"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Duration float64 `json:"duration"`
ErrorMessage interface{} `json:"error_message"`
} `json:"metadata"`
AudioUrl string `json:"audio_url"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
ModelName string `json:"model_name"`
DisplayName string `json:"display_name"`
ImageLargeUrl string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
} `json:"data"`
} `json:"data"`
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
@ -329,22 +259,22 @@ func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error)
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
return LumaTaskVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return QueryRespVo{}, fmt.Errorf("请求 API 失败:%v", err)
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return QueryRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil

View File

@ -547,7 +547,7 @@ const removeChat = function (chat) {
return e1.id === e2.id
})
//
newChat();
_newChat();
}).catch(e => {
ElMessage.error("操作失败:" + e.message);
})

View File

@ -302,6 +302,9 @@
<el-form-item label="Suno 算力" prop="suno_power">
<el-input v-model.number="system['suno_power']" placeholder="使用 Suno 生成一首音乐消耗算力"/>
</el-form-item>
<el-form-item label="Luma 算力" prop="luma_power">
<el-input v-model.number="system['luma_power']" placeholder="使用 Luma 生成一段视频消耗算力"/>
</el-form-item>
</el-tab-pane>
</el-tabs>