diff --git a/CHANGELOG.md b/CHANGELOG.md index 50a207e1..59a127ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ * 功能优化:重构找回密码模块,支持通过手机或者邮箱找回密码 * 功能优化:管理后台给可以拖动排序的组件添加拖动图标 * 功能优化:Suno 支持合成完整歌曲,和上传自己的音乐作品进行二次创作 +* Bug修复:手机端角色和模型选择不生效 +* Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常 +* 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验 +* 功能优化:给算力增减服务全部加上数据库事务和同步锁 +* 功能优化:支持用户在前端对话界面选择插件 * 功能新增:支持 Luma 文生视频功能 ## v4.1.2 diff --git a/api/core/app_server.go b/api/core/app_server.go index 92967067..0eebbc93 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -201,7 +201,6 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/admin/logout" || c.Request.URL.Path == "/api/admin/login/captcha" || c.Request.URL.Path == "/api/user/register" || - c.Request.URL.Path == "/api/user/session" || c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/chat/list" || diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 9464ec8b..42c86a2b 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -57,6 +57,7 @@ type ChatSession struct { ClientIP string `json:"client_ip"` // 客户端 IP ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 Model ChatModel `json:"model"` // GPT 模型 + Tools string `json:"tools"` // 函数 } type ChatModel struct { diff --git a/api/core/types/config.go b/api/core/types/config.go index 05dec367..9638f620 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -150,8 +150,9 @@ type SystemConfig struct { MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 - DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 + DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力 SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力 + LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力 WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 diff --git a/api/core/types/task.go b/api/core/types/task.go index 2affae7d..d41e592a 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -85,7 +85,6 @@ type SunoTask struct { Channel string `json:"channel"` UserId int `json:"user_id"` Type int `json:"type"` - TaskId string `json:"task_id"` Title string `json:"title"` RefTaskId string `json:"ref_task_id,omitempty"` RefSongId string `json:"ref_song_id,omitempty"` @@ -97,3 +96,30 @@ type SunoTask struct { SongId string `json:"song_id,omitempty"` // 合并歌曲ID AudioURL string `json:"audio_url"` // 用户上传音频地址 } + +const ( + VideoLuma = "luma" + VideoRunway = "runway" + VideoCog = "cog" +) + +type VideoTask struct { + Id uint `json:"id"` + Channel string `json:"channel"` + UserId int `json:"user_id"` + Type string `json:"type"` + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + Params VideoParams `json:"params"` +} + +type VideoParams struct { + PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词 + Loop bool `json:"loop"` // 是否循环参考图 + StartImgURL string `json:"start_img_url"` // 第一帧参考图地址 + EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址 + Model string `json:"model"` // 使用哪个模型生成视频 + Radio string `json:"radio"` // 视频尺寸 + Style string `json:"style"` // 风格 + Duration int `json:"duration"` // 视频时长(秒) +} diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 6b1e9f8e..58b9870a 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -46,9 +46,10 @@ type ChatHandler struct { licenseService *service.LicenseService ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message + userService *service.UserService } -func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { +func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler { return &ChatHandler{ BaseHandler: handler.BaseHandler{App: app, DB: db}, redis: redis, @@ -56,6 +57,7 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag licenseService: licenseService, ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ChatContexts: types.NewLMap[string, []types.Message](), + userService: userService, } } @@ -71,6 +73,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { roleId := h.GetInt(c, "role_id", 0) chatId := c.Query("chat_id") modelId := h.GetInt(c, "model_id", 0) + tools := c.Query("tools") client := types.NewWsClient(ws) var chatRole model.ChatRole @@ -97,6 +100,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { SessionId: sessionId, ClientIP: c.ClientIP(), UserId: h.GetLoginUserId(c), + Tools: tools, } // use old chat data override the chat model and role ID @@ -209,34 +213,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } req.Temperature = session.Model.Temperature req.MaxTokens = session.Model.MaxTokens - // OpenAI 支持函数功能 - var items []model.Function - res = h.DB.Where("enabled", true).Find(&items) - if res.Error == nil { - var tools = make([]types.Tool, 0) - for _, v := range items { - var parameters map[string]interface{} - err = utils.JsonDecode(v.Parameters, ¶meters) - if err != nil { - continue - } - tool := types.Tool{ - Type: "function", - Function: types.Function{ - Name: v.Name, - Description: v.Description, - Parameters: parameters, - }, - } - if v, ok := parameters["required"]; v == nil || !ok { - tool.Function.Parameters["required"] = []string{} - } - tools = append(tools, tool) - } - if len(tools) > 0 { - req.Tools = tools - req.ToolChoice = "auto" + if session.Tools != "" { + toolIds := strings.Split(session.Tools, ",") + var items []model.Function + res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items) + if res.Error == nil { + var tools = make([]types.Tool, 0) + for _, v := range items { + var parameters map[string]interface{} + err = utils.JsonDecode(v.Parameters, ¶meters) + if err != nil { + continue + } + tool := types.Tool{ + Type: "function", + Function: types.Function{ + Name: v.Name, + Description: v.Description, + Parameters: parameters, + }, + } + if v, ok := parameters["required"]; v == nil || !ok { + tool.Function.Parameters["required"] = []string{} + } + tools = append(tools, tool) + } + + if len(tools) > 0 { + req.Tools = tools + req.ToolChoice = "auto" + } } } @@ -270,7 +277,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model) tokens += tks + promptTokens - for _, v := range messages { + for i := len(messages) - 1; i >= 0; i-- { + v := messages[i] tks, _ := utils.CalcTokens(v.Content, req.Model) // 上下文 token 超出了模型的最大上下文长度 if tokens+tks >= session.Model.MaxContext { @@ -481,24 +489,15 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p if session.Model.Power > 0 { power = session.Model.Power } - res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power)) - if res.Error == nil { - // 记录算力消费日志 - var u model.User - h.DB.Where("id", userVo.Id).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: userVo.Id, - Username: userVo.Username, - Type: types.PowerConsume, - Amount: power, - Mark: types.PowerSub, - Balance: u.Power, - Model: session.Model.Value, - Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), - CreatedAt: time.Now(), - }) - } + err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{ + Type: types.PowerConsume, + Model: session.Model.Value, + Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), + }) + if err != nil { + logger.Error(err) + } } func (h *ChatHandler) saveChatHistory( diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index d09f0651..bcf44ba8 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -11,32 +11,33 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/dalle" "geekai/service/oss" "geekai/store/model" "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "github.com/gorilla/websocket" - "net/http" - "time" - "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" "gorm.io/gorm" + "net/http" ) type DallJobHandler struct { BaseHandler - redis *redis.Client - service *dalle.Service - uploader *oss.UploaderManager + redis *redis.Client + dallService *dalle.Service + uploader *oss.UploaderManager + userService *service.UserService } -func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { +func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler { return &DallJobHandler{ - service: service, - uploader: manager, + dallService: service, + uploader: manager, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -61,14 +62,14 @@ func (h *DallJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.dallService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) go func() { for { _, msg, err := client.Receive() if err != nil { client.Close() - h.service.Clients.Delete(uint(userId)) + h.dallService.Clients.Delete(uint(userId)) return } @@ -127,7 +128,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.DallTask{ + h.dallService.PushTask(types.DallTask{ JobId: job.Id, UserId: uint(userId), Prompt: data.Prompt, @@ -137,7 +138,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { Power: job.Power, }) - client := h.service.Clients.Get(job.UserId) + client := h.dallService.Clients.Get(job.UserId) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -175,7 +176,7 @@ func (h *DallJobHandler) JobList(c *gin.Context) { } // JobList 获取任务列表 -func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { +func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) { session := h.DB.Session(&gorm.Session{}) if finish { @@ -193,11 +194,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in offset := (page - 1) * pageSize session = session.Offset(offset).Limit(pageSize) } + // 统计总数 + var total int64 + session.Model(&model.DallJob{}).Count(&total) var items []model.DallJob res := session.Find(&items) if res.Error != nil { - return res.Error, nil + return res.Error, vo.Page{} } var jobs = make([]vo.DallJob, 0) @@ -210,7 +214,7 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in jobs = append(jobs, job) } - return nil, jobs + return nil, vo.NewPage(total, page, pageSize, jobs) } // Remove remove task image @@ -233,26 +237,11 @@ func (h *DallJobHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "dall-e-3", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index 6917efde..f1838d4d 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -8,15 +8,16 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "errors" + "fmt" "geekai/core" "geekai/core/types" "geekai/service/dalle" "geekai/service/oss" "geekai/store/model" + "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "errors" - "fmt" "strings" "time" @@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { resp.SUCCESS(c, content) } + +// List 获取所有的工具函数列表 +func (h *FunctionHandler) List(c *gin.Context) { + var items []model.Function + err := h.DB.Where("enabled", true).Find(&items).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + tools := make([]vo.Function, 0) + for _, v := range items { + var f vo.Function + err = utils.CopyObject(v, &f) + if err != nil { + continue + } + f.Action = "" + f.Token = "" + tools = append(tools, f) + } + + resp.SUCCESS(c, tools) +} diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index 8196a81e..b4147deb 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -15,6 +15,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/store/model" "geekai/utils" "github.com/gin-gonic/gin" @@ -30,13 +31,15 @@ import ( // MarkMapHandler 生成思维导图 type MarkMapHandler struct { BaseHandler - clients *types.LMap[int, *types.WsClient] + clients *types.LMap[int, *types.WsClient] + userService *service.UserService } -func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { +func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler { return &MarkMapHandler{ BaseHandler: BaseHandler{App: app, DB: db}, clients: types.NewLMap[int, *types.WsClient](), + userService: userService, } } @@ -185,22 +188,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode // 扣减算力 if chatModel.Power > 0 { - res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) - if res.Error == nil { - // 记录算力消费日志 - var u model.User - h.DB.Where("id", userId).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: u.Id, - Username: u.Username, - Type: types.PowerConsume, - Amount: chatModel.Power, - Mark: types.PowerSub, - Balance: u.Power, - Model: chatModel.Value, - Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: chatModel.Value, + Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), + }) + if err != nil { + return err } } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 212729b2..34996c81 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -30,16 +30,18 @@ import ( type MidJourneyHandler struct { BaseHandler - service *mj.Service - snowflake *service.Snowflake - uploader *oss.UploaderManager + mjService *mj.Service + snowflake *service.Snowflake + uploader *oss.UploaderManager + userService *service.UserService } -func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler { +func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler { return &MidJourneyHandler{ - snowflake: snowflake, - service: service, - uploader: manager, + snowflake: snowflake, + mjService: service, + uploader: manager, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -80,7 +82,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.mjService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -196,7 +198,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, TaskId: taskId, Type: types.TaskType(data.TaskType), @@ -208,28 +210,22 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -269,7 +265,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskUpscale, UserId: userId, @@ -280,27 +276,22 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } + // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -334,7 +325,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskVariation, UserId: userId, @@ -345,28 +336,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } - // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -401,7 +385,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { } // JobList 获取 MJ 任务列表 -func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { +func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) { session := h.DB.Session(&gorm.Session{}) if finish { session = session.Where("progress >= ?", 100).Order("id DESC") @@ -419,10 +403,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize session = session.Offset(offset).Limit(pageSize) } + // 统计总数 + var total int64 + session.Model(&model.MidJourneyJob{}).Count(&total) + var items []model.MidJourneyJob res := session.Find(&items) if res.Error != nil { - return res.Error, nil + return res.Error, vo.Page{} } var jobs = make([]vo.MidJourneyJob, 0) @@ -442,7 +430,7 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize jobs = append(jobs, job) } - return nil, jobs + return nil, vo.NewPage(total, page, pageSize, jobs) } // Remove remove task image @@ -465,25 +453,11 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "mid-journey", - Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s,Err: %s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "mid-journey", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) @@ -498,7 +472,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } diff --git a/api/handler/redeem_handler.go b/api/handler/redeem_handler.go index 4f557ce9..b6759ffd 100644 --- a/api/handler/redeem_handler.go +++ b/api/handler/redeem_handler.go @@ -11,6 +11,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/store/model" "geekai/utils/resp" "github.com/gin-gonic/gin" @@ -21,11 +22,12 @@ import ( type RedeemHandler struct { BaseHandler - lock sync.Mutex + lock sync.Mutex + userService *service.UserService } -func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler { - return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}} +func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler { + return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService} } func (h *RedeemHandler) Verify(c *gin.Context) { @@ -59,7 +61,11 @@ func (h *RedeemHandler) Verify(c *gin.Context) { } tx := h.DB.Begin() - err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", item.Power)).Error + err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{ + Type: types.PowerRedeem, + Model: "兑换码", + Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) @@ -76,26 +82,6 @@ func (h *RedeemHandler) Verify(c *gin.Context) { return } - // 记录算力充值日志 - var user model.User - err = tx.Where("id", userId).First(&user).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - h.DB.Create(&model.PowerLog{ - UserId: userId, - Username: user.Username, - Type: types.PowerRedeem, - Amount: item.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "兑换码", - Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]), - CreatedAt: time.Now(), - }) tx.Commit() resp.SUCCESS(c) diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 9cbc60fb..27568d39 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -31,19 +31,27 @@ import ( type SdJobHandler struct { BaseHandler - redis *redis.Client - service *sd.Service - uploader *oss.UploaderManager - snowflake *service.Snowflake - leveldb *store.LevelDB + redis *redis.Client + sdService *sd.Service + uploader *oss.UploaderManager + snowflake *service.Snowflake + leveldb *store.LevelDB + userService *service.UserService } -func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { +func NewSdJobHandler(app *core.AppServer, + db *gorm.DB, + service *sd.Service, + manager *oss.UploaderManager, + snowflake *service.Snowflake, + userService *service.UserService, + levelDB *store.LevelDB) *SdJobHandler { return &SdJobHandler{ - service: service, - uploader: manager, - snowflake: snowflake, - leveldb: levelDB, + sdService: service, + uploader: manager, + snowflake: snowflake, + leveldb: levelDB, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -68,7 +76,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.sdService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -159,34 +167,27 @@ func (h *SdJobHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.SdTask{ + h.sdService.PushTask(types.SdTask{ Id: int(job.Id), Type: types.TaskImage, Params: params, UserId: userId, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.sdService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "stable-diffusion", - Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "stable-diffusion", + Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } resp.SUCCESS(c) @@ -223,7 +224,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) { } // JobList 获取 MJ 任务列表 -func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { +func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) { session := h.DB.Session(&gorm.Session{}) if finish { @@ -242,10 +243,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, session = session.Offset(offset).Limit(pageSize) } + // 统计总数 + var total int64 + session.Model(&model.SdJob{}).Count(&total) + var items []model.SdJob res := session.Find(&items) if res.Error != nil { - return res.Error, nil + return res.Error, vo.Page{} } var jobs = make([]vo.SdJob, 0) @@ -267,7 +272,7 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, jobs = append(jobs, job) } - return nil, jobs + return nil, vo.NewPage(total, page, pageSize, jobs) } // Remove remove task image @@ -290,25 +295,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "stable-diffusion", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "stable-diffusion", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 9151a0eb..624b703e 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -11,6 +11,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/oss" "geekai/service/suno" "geekai/store/model" @@ -26,18 +27,20 @@ import ( type SunoHandler struct { BaseHandler - service *suno.Service - uploader *oss.UploaderManager + sunoService *suno.Service + uploader *oss.UploaderManager + userService *service.UserService } -func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { +func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler { return &SunoHandler{ BaseHandler: BaseHandler{ App: app, DB: db, }, - service: service, - uploader: uploader, + sunoService: service, + uploader: uploader, + userService: userService, } } @@ -58,7 +61,7 @@ func (h *SunoHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.sunoService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -123,7 +126,7 @@ func (h *SunoHandler) Create(c *gin.Context) { } // 创建任务 - h.service.PushTask(types.SunoTask{ + h.sunoService.PushTask(types.SunoTask{ Id: job.Id, UserId: job.UserId, Type: job.Type, @@ -140,24 +143,17 @@ func (h *SunoHandler) Create(c *gin.Context) { }) // update user's power - tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), + CreatedAt: time.Now(), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } - client := h.service.Clients.Get(uint(job.UserId)) + client := h.sunoService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -166,8 +162,8 @@ func (h *SunoHandler) Create(c *gin.Context) { func (h *SunoHandler) List(c *gin.Context) { userId := h.GetLoginUserId(c) - page := h.GetInt(c, "page", 0) - pageSize := h.GetInt(c, "page_size", 0) + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) // 统计总数 @@ -239,25 +235,11 @@ func (h *SunoHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: job.ModelName, + Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 837fecc3..096da7c5 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -34,6 +34,7 @@ type UserHandler struct { redis *redis.Client licenseService *service.LicenseService captcha *service.CaptchaService + userService *service.UserService } func NewUserHandler( @@ -42,6 +43,7 @@ func NewUserHandler( searcher *xdb.Searcher, client *redis.Client, captcha *service.CaptchaService, + userService *service.UserService, licenseService *service.LicenseService) *UserHandler { return &UserHandler{ BaseHandler: BaseHandler{DB: db, App: app}, @@ -49,6 +51,7 @@ func NewUserHandler( redis: client, captcha: captcha, licenseService: licenseService, + userService: userService, } } @@ -155,10 +158,9 @@ func (h *UserHandler) Register(c *gin.Context) { user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)) } - res := h.DB.Create(&user) - if res.Error != nil { - resp.ERROR(c, "保存数据失败") - logger.Error(res.Error) + tx := h.DB.Begin() + if err := tx.Create(&user).Error; err != nil { + resp.ERROR(c, err.Error()) return } @@ -167,35 +169,35 @@ func (h *UserHandler) Register(c *gin.Context) { // 增加邀请数量 h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) if h.App.SysConfig.InvitePower > 0 { - h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) - // 记录邀请算力充值日志 - var inviter model.User - h.DB.Where("id", inviteCode.UserId).First(&inviter) - h.DB.Create(&model.PowerLog{ - UserId: inviter.Id, - Username: inviter.Username, - Type: types.PowerInvite, - Amount: h.App.SysConfig.InvitePower, - Balance: inviter.Power, - Mark: types.PowerAdd, - Model: "", - Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), - CreatedAt: time.Now(), + err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{ + Type: types.PowerInvite, + Model: "", + Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } } // 添加邀请记录 - h.DB.Create(&model.InviteLog{ + err := tx.Create(&model.InviteLog{ InviterId: inviteCode.UserId, UserId: user.Id, Username: user.Username, InviteCode: inviteCode.Code, Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), - }) + }).Error + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } } + tx.Commit() _ = h.redis.Del(c, key) // 注册成功,删除短信验证码 - // 自动登录创建 token token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "user_id": user.Id, diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go new file mode 100644 index 00000000..72a1db06 --- /dev/null +++ b/api/handler/video_handler.go @@ -0,0 +1,233 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/service/oss" + "geekai/service/video" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "net/http" +) + +type VideoHandler struct { + BaseHandler + videoService *video.Service + uploader *oss.UploaderManager + userService *service.UserService +} + +func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler { + return &VideoHandler{ + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + videoService: service, + uploader: uploader, + userService: userService, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *VideoHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.videoService.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + +func (h *VideoHandler) LumaCreate(c *gin.Context) { + + var data struct { + Prompt string `json:"prompt"` + FirstFrameImg string `json:"first_frame_img,omitempty"` + EndFrameImg string `json:"end_frame_img,omitempty"` + ExpandPrompt bool `json:"expand_prompt,omitempty"` + Loop bool `json:"loop,omitempty"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + if data.Prompt == "" { + resp.ERROR(c, "prompt is needed") + return + } + + userId := int(h.GetLoginUserId(c)) + params := types.VideoParams{ + PromptOptimize: data.ExpandPrompt, + Loop: data.Loop, + StartImgURL: data.FirstFrameImg, + EndImgURL: data.EndFrameImg, + } + // 插入数据库 + job := model.VideoJob{ + UserId: userId, + Type: types.VideoLuma, + Prompt: data.Prompt, + Power: h.App.SysConfig.LumaPower, + Params: utils.JsonEncode(params), + } + tx := h.DB.Create(&job) + if tx.Error != nil { + resp.ERROR(c, tx.Error.Error()) + return + } + + // 创建任务 + h.videoService.PushTask(types.VideoTask{ + Id: job.Id, + UserId: userId, + Type: types.VideoLuma, + Prompt: data.Prompt, + Params: params, + }) + + // update user's power + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "luma", + Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + client := h.videoService.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) +} + +func (h *VideoHandler) List(c *gin.Context) { + userId := h.GetLoginUserId(c) + t := c.Query("type") + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + all := h.GetBool(c, "all") + session := h.DB.Session(&gorm.Session{}).Where("user_id", userId) + if t != "" { + session = session.Where("type", t) + } + if all { + session = session.Where("publish", 0).Where("progress", 100) + } else { + session = session.Where("user_id", h.GetLoginUserId(c)) + } + // 统计总数 + var total int64 + session.Model(&model.VideoJob{}).Count(&total) + + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + var list []model.VideoJob + err := session.Order("id desc").Find(&list).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + // 转换为 VO + items := make([]vo.VideoJob, 0) + for _, v := range list { + var item vo.VideoJob + err = utils.CopyObject(v, &item) + if err != nil { + continue + } + item.CreatedAt = v.CreatedAt.Unix() + items = append(items, item) + } + + resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items)) +} + +func (h *VideoHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + var job model.VideoJob + err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + // 删除任务 + tx := h.DB.Begin() + if err := tx.Delete(&job).Error; err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + + // 如果任务未完成,或者任务失败,则恢复用户算力 + if job.Progress != 100 { + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "luma", + Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } + } + tx.Commit() + + // 删除文件 + _ = h.uploader.GetUploadHandler().Delete(job.CoverURL) + _ = h.uploader.GetUploadHandler().Delete(job.VideoURL) +} + +func (h *VideoHandler) Publish(c *gin.Context) { + id := h.GetInt(c, "id", 0) + userId := h.GetLoginUserId(c) + publish := h.GetBool(c, "publish") + var job model.VideoJob + err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + err = h.DB.Model(&job).UpdateColumn("publish", publish).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c) +} diff --git a/api/main.go b/api/main.go index 70611a6b..274f67b2 100644 --- a/api/main.go +++ b/api/main.go @@ -24,6 +24,7 @@ import ( "geekai/service/sd" "geekai/service/sms" "geekai/service/suno" + "geekai/service/video" "geekai/store" "io" "log" @@ -199,9 +200,16 @@ func main() { s.Run() s.SyncTaskProgress() s.CheckTaskNotify() - s.DownloadImages() + s.DownloadFiles() }), - + fx.Provide(video.NewService), + fx.Invoke(func(s *video.Service) { + s.Run() + s.SyncTaskProgress() + s.CheckTaskNotify() + s.DownloadFiles() + }), + fx.Provide(service.NewUserService), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewJPayService), @@ -425,6 +433,7 @@ func main() { group.POST("weibo", h.WeiBo) group.POST("zaobao", h.ZaoBao) group.POST("dalle3", h.Dall3) + group.GET("list", h.List) }), fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { group := s.Engine.Group("/api/admin/chat/") @@ -484,6 +493,15 @@ func main() { group.GET("play", h.Play) group.POST("lyric", h.Lyric) }), + fx.Provide(handler.NewVideoHandler), + fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) { + group := s.Engine.Group("/api/video") + group.Any("client", h.Client) + group.POST("luma/create", h.LumaCreate) + group.GET("list", h.List) + group.GET("remove", h.Remove) + group.GET("publish", h.Publish) + }), fx.Provide(handler.NewTestHandler), fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { group := s.Engine.Group("/api/test") diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 5b1d4ab8..4ea1082e 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -35,9 +35,10 @@ type Service struct { taskQueue *store.RedisQueue notifyQueue *store.RedisQueue Clients *types.LMap[uint, *types.WsClient] // UserId => Client + userService *service.UserService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, @@ -45,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), Clients: types.NewLMap[uint, *types.WsClient](), uploadManager: manager, + userService: userService, } } @@ -122,32 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { return "", errors.New("insufficient of power") } - // 更新用户算力 - tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - var u model.User - s.db.Where("id", user.Id).First(&u) - s.db.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: task.Power, - Balance: u.Power, - Mark: types.PowerSub, - Model: "dall-e-3", - Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), - CreatedAt: time.Now(), - }) + // 扣减算力 + err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "dall-e-3", + Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), + }) + if err != nil { + return "", fmt.Errorf("error with decrease power: %v", err) } // get image generation API KEY var apiKey model.ApiKey - tx = s.db.Where("type", "dalle"). + err = s.db.Where("type", "dalle"). Where("enabled", true). - Order("last_used_at ASC").First(&apiKey) - if tx.Error != nil { - return "", fmt.Errorf("no available DALL-E api key: %v", tx.Error) + Order("last_used_at ASC").First(&apiKey).Error + if err != nil { + return "", fmt.Errorf("no available DALL-E api key: %v", err) } var res imgRes @@ -181,13 +174,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // update the api key last use time s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // update task progress - tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ "progress": 100, "org_url": res.Data[0].Url, "prompt": prompt, - }) - if tx.Error != nil { - return "", fmt.Errorf("err with update database: %v", tx.Error) + }).Error + if err != nil { + return "", fmt.Errorf("err with update database: %v", err) } s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) diff --git a/api/service/suno/service.go b/api/service/suno/service.go index a49bcb47..e3e502dd 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -242,6 +242,10 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) { return RespVo{}, fmt.Errorf("请求 API 出错:%v", err) } + if r.StatusCode != 200 { + return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) + } + body, _ := io.ReadAll(r.Body) err = json.Unmarshal(body, &res) if err != nil { @@ -279,7 +283,7 @@ func (s *Service) CheckTaskNotify() { }() } -func (s *Service) DownloadImages() { +func (s *Service) DownloadFiles() { go func() { var items []model.SunoJob for { @@ -425,11 +429,11 @@ type QueryRespVo struct { func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) { // 读取 API KEY var apiKey model.ApiKey - tx := s.db.Session(&gorm.Session{}).Where("type", "suno"). + err := s.db.Session(&gorm.Session{}).Where("type", "suno"). Where("api_url", channel). Where("enabled", true). - Order("last_used_at DESC").First(&apiKey) - if tx.Error != nil { + Order("last_used_at DESC").First(&apiKey).Error + if err != nil { return QueryRespVo{}, errors.New("no available API KEY for Suno") } diff --git a/api/service/user_service.go b/api/service/user_service.go new file mode 100644 index 00000000..ea086d01 --- /dev/null +++ b/api/service/user_service.go @@ -0,0 +1,83 @@ +package service + +import ( + "fmt" + "geekai/core/types" + "geekai/store/model" + "gorm.io/gorm" + "sync" + "time" +) + +type UserService struct { + db *gorm.DB + lock sync.Mutex +} + +func NewUserService(db *gorm.DB) *UserService { + return &UserService{db: db, lock: sync.Mutex{}} +} + +// IncreasePower 增加用户算力 +func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error { + s.lock.Lock() + defer s.lock.Unlock() + + tx := s.db.Begin() + err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error + if err != nil { + tx.Rollback() + return err + } + var user model.User + tx.Where("id", userId).First(&user) + err = tx.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: log.Type, + Amount: power, + Balance: user.Power, + Mark: types.PowerAdd, + Model: log.Model, + Remark: log.Remark, + CreatedAt: time.Now(), + }).Error + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +// DecreasePower 减少用户算力 +func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error { + s.lock.Lock() + defer s.lock.Unlock() + + tx := s.db.Begin() + err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error + if err != nil { + tx.Rollback() + return fmt.Errorf("扣减算力失败:%v", err) + } + var user model.User + tx.Where("id", userId).First(&user) + err = tx.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: log.Type, + Amount: power, + Balance: user.Power, + Mark: types.PowerSub, + Model: log.Model, + Remark: log.Remark, + CreatedAt: time.Now(), + }).Error + if err != nil { + tx.Rollback() + return fmt.Errorf("记录算力日志失败:%v", err) + } + tx.Commit() + return nil +} diff --git a/api/service/video/luma.go b/api/service/video/luma.go new file mode 100644 index 00000000..8d71b3e4 --- /dev/null +++ b/api/service/video/luma.go @@ -0,0 +1,326 @@ +package video + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "github.com/go-redis/redis/v8" + "io" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +func (s *Service) PushTask(task types.VideoTask) { + logger.Infof("add a new Video task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + // 将数据库中未提交的人物加载到队列 + var jobs []model.VideoJob + s.db.Where("task_id", "").Find(&jobs) + for _, v := range jobs { + var params types.VideoParams + if err := utils.JsonDecode(v.Params, ¶ms); err != nil { + logger.Errorf("unmarshal params failed: %v", err) + continue + } + s.PushTask(types.VideoTask{ + Id: v.Id, + Channel: v.Channel, + UserId: v.UserId, + Type: v.Type, + TaskId: v.TaskId, + Prompt: v.Prompt, + Params: params, + }) + } + logger.Info("Starting Video job consumer...") + go func() { + for { + var task types.VideoTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + var r LumaRespVo + r, err = s.LumaCreate(task) + if err != nil { + logger.Errorf("create task with error: %v", err) + s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "err_msg": err.Error(), + "progress": service.FailTaskProgress, + }) + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed}) + continue + } + + // 更新任务信息 + err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ + "task_id": r.Id, + "channel": r.Channel, + "prompt_ext": r.Prompt, + }).Error + if err != nil { + logger.Errorf("update task with error: %v", err) + s.PushTask(task) + } + } + }() +} + +type LumaRespVo struct { + Id string `json:"id"` + Prompt string `json:"prompt"` + State string `json:"state"` + CreatedAt time.Time `json:"created_at"` + Video interface{} `json:"video"` + Liked interface{} `json:"liked"` + EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` + Channel string `json:"channel,omitempty"` +} + +func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true) + if task.Channel != "" { + session = session.Where("api_url", task.Channel) + } + tx := session.Order("last_used_at DESC").First(&apiKey) + if tx.Error != nil { + return LumaRespVo{}, errors.New("no available API KEY for Luma") + } + + reqBody := map[string]interface{}{ + "user_prompt": task.Prompt, + "expand_prompt": task.Params.PromptOptimize, + "loop": task.Params.Loop, + "image_url": task.Params.StartImgURL, + "image_end_url": task.Params.EndImgURL, + } + var res LumaRespVo + apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL) + logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody) + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). + Post(apiURL) + if err != nil { + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.StatusCode != 200 && r.StatusCode != 201 { + return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String()) + } + + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return LumaRespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + // update the last_use_at for api key + apiKey.LastUsedAt = time.Now().Unix() + session.Updates(&apiKey) + res.Channel = apiKey.ApiURL + return res, nil +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running Suno task notify checking ...") + for { + var message service.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadFiles() { + go func() { + var items []model.VideoJob + for { + res := s.db.Where("progress", 102).Find(&items) + if res.Error != nil { + continue + } + + for _, v := range items { + if v.WaterURL == "" { + continue + } + + logger.Infof("try download video: %s", v.WaterURL) + videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true) + if err != nil { + logger.Errorf("download video with error: %v", err) + continue + } + logger.Infof("download video success: %s", videoURL) + v.WaterURL = videoURL + + if v.VideoURL != "" { + logger.Infof("try download no water video: %s", v.VideoURL) + videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true) + if err != nil { + logger.Errorf("download video with error: %v", err) + continue + } + } + logger.Info("download no water video success: %s", videoURL) + v.VideoURL = videoURL + v.Progress = 100 + s.db.Updates(&v) + s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) + } + + time.Sleep(time.Second * 10) + } + }() +} + +// SyncTaskProgress 异步拉取任务 +func (s *Service) SyncTaskProgress() { + go func() { + var jobs []model.VideoJob + for { + res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs) + if res.Error != nil { + continue + } + + for _, job := range jobs { + task, err := s.QueryLumaTask(job.TaskId, job.Channel) + if err != nil { + logger.Errorf("query task with error: %v", err) + // 更新任务信息 + s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": service.FailTaskProgress, // 102 表示资源未下载完成, + "err_msg": err.Error(), + }) + continue + } + + logger.Debugf("task: %+v", task) + if task.State == "completed" { // 更新任务信息 + data := map[string]interface{}{ + "progress": 102, // 102 表示资源未下载完成, + "water_url": task.Video.Url, + "raw_data": utils.JsonEncode(task), + "prompt_ext": task.Prompt, + } + if task.Video.DownloadUrl != "" { + data["video_url"] = task.Video.DownloadUrl + } + err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error + if err != nil { + logger.Errorf("更新数据库失败:%v", err) + continue + } + } + + } + + time.Sleep(time.Second * 10) + } + }() +} + +type LumaTaskVo struct { + Id string `json:"id"` + Liked interface{} `json:"liked"` + State string `json:"state"` + Video struct { + Url string `json:"url"` + Width int `json:"width"` + Height int `json:"height"` + DownloadUrl string `json:"download_url"` + } `json:"video"` + Prompt string `json:"prompt"` + CreatedAt time.Time `json:"created_at"` + EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"` +} + +func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) { + // 读取 API KEY + var apiKey model.ApiKey + err := s.db.Session(&gorm.Session{}).Where("type", "luma"). + Where("api_url", channel). + Where("enabled", true). + Order("last_used_at DESC").First(&apiKey).Error + if err != nil { + return LumaTaskVo{}, errors.New("no available API KEY for Luma") + } + + apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId) + var res LumaTaskVo + r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL) + + if err != nil { + return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err) + } + defer r.Body.Close() + + if r.StatusCode != 200 { + return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String()) + } + + body, _ := io.ReadAll(r.Body) + err = json.Unmarshal(body, &res) + if err != nil { + return LumaTaskVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body)) + } + + return res, nil +} diff --git a/api/service/xxl_job_service.go b/api/service/xxl_job_service.go index 2adecf1b..ef701730 100644 --- a/api/service/xxl_job_service.go +++ b/api/service/xxl_job_service.go @@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms // 自动将 VIP 会员的算力补充到每月赠送的最大值 func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) { logger.Info("开始进行月底账号盘点...") - var users []model.User - res := e.db.Where("vip", 1).Where("status", 1).Find(&users) - if res.Error != nil { - return "No vip users found" - } - - var sysConfig model.Config - res = e.db.Where("marker", "system").First(&sysConfig) - if res.Error != nil { - return "error with get system config: " + res.Error.Error() - } - - var config types.SystemConfig - err := utils.JsonDecode(sysConfig.Config, &config) - if err != nil { - return "error with decode system config: " + err.Error() - } - - for _, u := range users { - // 处理过期的 VIP - if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() { - u.Vip = false - e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false) - continue - } - if u.Power < config.VipMonthPower { - power := config.VipMonthPower - u.Power - // update user - tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power)) - // 记录算力变动日志 - if tx.Error == nil { - var user model.User - e.db.Where("id", u.Id).First(&user) - e.db.Create(&model.PowerLog{ - UserId: u.Id, - Username: u.Username, - Type: types.PowerRecharge, - Amount: power, - Mark: types.PowerAdd, - Balance: user.Power, - Model: "系统盘点", - Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower), - CreatedAt: time.Now(), - }) - } - } - } - logger.Info("月底盘点完成!") return "success" } diff --git a/api/store/model/video_job.go b/api/store/model/video_job.go new file mode 100644 index 00000000..5dc7cb3e --- /dev/null +++ b/api/store/model/video_job.go @@ -0,0 +1,27 @@ +package model + +import "time" + +type VideoJob struct { + Id uint `gorm:"primarykey;column:id"` + UserId int + Channel string // 频道 + Type string // luma,runway,cog + TaskId string + Prompt string // 提示词 + PromptExt string // 优化后提示词 + CoverURL string // 封面图 URL + VideoURL string // 无水印视频 URL + WaterURL string // 有水印视频 URL + Progress int // 任务进度 + Publish bool // 是否发布 + ErrMsg string // 错误信息 + RawData string // 原始数据 json + Power int // 消耗算力 + Params string // 任务参数 + CreatedAt time.Time +} + +func (VideoJob) TableName() string { + return "chatgpt_video_jobs" +} diff --git a/api/store/vo/suno_job.go b/api/store/vo/suno_job.go index 97a18a3a..70dca573 100644 --- a/api/store/vo/suno_job.go +++ b/api/store/vo/suno_job.go @@ -28,7 +28,3 @@ type SunoJob struct { PlayTimes int `json:"play_times"` // 播放次数 CreatedAt int64 `json:"created_at"` } - -func (SunoJob) TableName() string { - return "chatgpt_suno_jobs" -} diff --git a/api/store/vo/video_job.go b/api/store/vo/video_job.go new file mode 100644 index 00000000..3582c667 --- /dev/null +++ b/api/store/vo/video_job.go @@ -0,0 +1,23 @@ +package vo + +import "geekai/core/types" + +type VideoJob struct { + Id uint `json:"id"` + UserId int `json:"user_id"` + Channel string `json:"channel"` + Type string `json:"type"` + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + PromptExt string `json:"prompt_ext"` // 提示词 + CoverURL string `json:"cover_url"` // 封面图 URL + VideoURL string `json:"video_url"` // 无水印视频 URL + WaterURL string `json:"water_url"` // 有水印视频 URL + Progress int `json:"progress"` // 任务进度 + Publish bool `json:"publish"` // 是否发布 + ErrMsg string `json:"err_msg"` // 错误信息 + RawData map[string]interface{} `json:"raw_data"` // 原始数据 json + Power int `json:"power"` // 消耗算力 + Params types.VideoParams `json:"params"` // 任务参数 + CreatedAt int64 `json:"created_at"` +} diff --git a/database/update-v4.1.3.sql b/database/update-v4.1.3.sql index 05f5cd3f..22dda9ba 100644 --- a/database/update-v4.1.3.sql +++ b/database/update-v4.1.3.sql @@ -1,2 +1,27 @@ ALTER TABLE `chatgpt_users` ADD `mobile` CHAR(11) NULL COMMENT '手机号' AFTER `username`; -ALTER TABLE `chatgpt_users` ADD `email` VARCHAR(50) NULL COMMENT '邮箱地址' AFTER `mobile`; \ No newline at end of file +ALTER TABLE `chatgpt_users` ADD `email` VARCHAR(50) NULL COMMENT '邮箱地址' AFTER `mobile`; + +CREATE TABLE `chatgpt_video_jobs` ( + `id` int NOT NULL, + `user_id` int NOT NULL COMMENT '用户 ID', + `channel` varchar(100) NOT NULL COMMENT '渠道', + `task_id` varchar(100) NOT NULL COMMENT '任务 ID', + `type` varchar(20) DEFAULT NULL COMMENT '任务类型,luma,runway,cogvideo', + `prompt` varchar(2000) NOT NULL COMMENT '提示词', + `prompt_ext` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '优化后提示词', + `cover_url` varchar(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '封面图地址', + `video_url` varchar(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '视频地址', + `water_url` varchar(512) DEFAULT NULL COMMENT '带水印的视频地址', + `progress` smallint DEFAULT '0' COMMENT '任务进度', + `publish` tinyint(1) NOT NULL COMMENT '是否发布', + `err_msg` varchar(255) DEFAULT NULL COMMENT '错误信息', + `raw_data` text COMMENT '原始数据', + `power` smallint NOT NULL DEFAULT '0' COMMENT '消耗算力', + `created_at` datetime NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='MidJourney 任务表'; + +ALTER TABLE `chatgpt_video_jobs`ADD PRIMARY KEY (`id`); + +ALTER TABLE `chatgpt_video_jobs` MODIFY `id` int NOT NULL AUTO_INCREMENT; + +ALTER TABLE `chatgpt_video_jobs` ADD `params` VARCHAR(512) NULL COMMENT '参数JSON' AFTER `raw_data`; \ No newline at end of file diff --git a/deploy/docker-compose.yaml b/deploy/docker-compose.yaml index 0030482a..03fa5a73 100644 --- a/deploy/docker-compose.yaml +++ b/deploy/docker-compose.yaml @@ -27,17 +27,17 @@ services: ports: - "6380:6379" - xxl-job-admin: - container_name: geekai-xxl-job-admin - image: registry.cn-shenzhen.aliyuncs.com/geekmaster/xxl-job-admin:2.4.0 - restart: always - ports: - - "8081:8080" - environment: - - PARAMS=--spring.config.location=/application.properties - volumes: - - ./logs/xxl-job:/data/applogs - - ./conf/xxl-job/application.properties:/application.properties +# xxl-job-admin: +# container_name: geekai-xxl-job-admin +# image: registry.cn-shenzhen.aliyuncs.com/geekmaster/xxl-job-admin:2.4.0 +# restart: always +# ports: +# - "8081:8080" +# environment: +# - PARAMS=--spring.config.location=/application.properties +# volumes: +# - ./logs/xxl-job:/data/applogs +# - ./conf/xxl-job/application.properties:/application.properties tika: image: registry.cn-shenzhen.aliyuncs.com/geekmaster/tika:latest @@ -46,14 +46,14 @@ services: ports: - "9998:9998" - midjourney-proxy: - image: registry.cn-shenzhen.aliyuncs.com/geekmaster/midjourney-proxy:2.6.2 - container_name: geekai-midjourney-proxy - restart: always - ports: - - "8082:8080" - volumes: - - ./conf/mj-proxy:/home/spring/config +# midjourney-proxy: +# image: registry.cn-shenzhen.aliyuncs.com/geekmaster/midjourney-proxy:2.6.2 +# container_name: geekai-midjourney-proxy +# restart: always +# ports: +# - "8082:8080" +# volumes: +# - ./conf/mj-proxy:/home/spring/config # 后端 API 程序 diff --git a/web/public/files/suno.mp3 b/web/public/files/suno.mp3 deleted file mode 100644 index 4aa2a844..00000000 Binary files a/web/public/files/suno.mp3 and /dev/null differ diff --git a/web/public/files/test.mp3 b/web/public/files/test.mp3 deleted file mode 100644 index 2a518c05..00000000 Binary files a/web/public/files/test.mp3 and /dev/null differ diff --git a/web/public/images/logo.png b/web/public/images/logo.png index 63f187ed..61441d01 100644 Binary files a/web/public/images/logo.png and b/web/public/images/logo.png differ diff --git a/web/src/assets/css/chat-plus.css b/web/src/assets/css/chat-plus.css deleted file mode 100644 index aa695bf9..00000000 --- a/web/src/assets/css/chat-plus.css +++ /dev/null @@ -1,334 +0,0 @@ -#app { - height: 100%; -} -#app .chat-page { - height: 100%; -} -#app .chat-page .el-aside { - padding: 10px; - width: var(--el-aside-width, 320px); -} -#app .chat-page .el-aside .chat-list { - display: flex; - flex-flow: column; - border-radius: 10px; - padding: 10px 0; -} -#app .chat-page .el-aside .chat-list .search-box { - flex-wrap: wrap; - padding: 10px 0; -} -#app .chat-page .el-aside .chat-list .search-box .search-input { - --el-input-bg-color: #363535; - --el-input-border-color: #464545; - --el-input-focus-border-color: #47fff1; - --el-input-hover-border-color: #2da39a; - box-shadow: none; -} -#app .chat-page .el-aside .chat-list ::-webkit-scrollbar { - width: 0; - height: 0; - background-color: transparent; -} -#app .chat-page .el-aside .chat-list .content { - width: 100%; - overflow-y: scroll; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item { - display: flex; - width: 100%; - justify-content: flex-start; - padding: 8px 12px; - cursor: pointer; - border: 1px solid #3c3c3c; - margin-bottom: 6px; - border-radius: 5px; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item:hover { - background-color: #343540; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .avatar { - width: 32px; - height: 32px; - border-radius: 50%; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .chat-title-input { - font-size: 14px; - margin-top: 4px; - margin-left: 10px; - overflow: hidden; - white-space: nowrap; - text-overflow: ellipsis; - width: 190px; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .chat-title { - color: #c1c1c1; - padding: 5px 10px; - max-width: 220px; - font-size: 14px; - overflow: hidden; - white-space: nowrap; - text-overflow: ellipsis; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .chat-opt { - position: absolute; - right: 2px; - top: 16px; - color: #fff; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .chat-opt .el-dropdown-link { - color: #fff; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item .chat-opt .el-icon { - margin-right: 8px; -} -#app .chat-page .el-aside .chat-list .content .chat-list-item.active { - background-color: #343540; - border-color: #21aa93; -} -#app .chat-page .el-aside .tool-box { - display: flex; - justify-content: center; - padding-top: 12px; - border-top: 1px solid #3c3c3c; -} -#app .chat-page .el-aside .tool-box .iconfont { - margin-right: 5px; -} -#app .chat-page .el-main { - overflow: hidden; - --el-main-padding: 0; - margin: 0; -} -#app .chat-page .el-main .chat-container { - min-width: 0; - flex: 1; - background-color: var(--el-bg-color); - color: var(--el-text-color-primary); -} -#app .chat-page .el-main .chat-container .chat-config { - height: 30px; - padding: 10px 30px; - display: flex; - justify-content: center; - justify-items: center; - border-bottom: 1px solid #d9d9e3; -} -#app .chat-page .el-main .chat-container .chat-config .role-select-label { - color: #fff; -} -#app .chat-page .el-main .chat-container .chat-config .el-select { - max-width: 150px; - margin-right: 10px; -} -#app .chat-page .el-main .chat-container .chat-config .role-select { - max-width: 130px; -} -#app .chat-page .el-main .chat-container .chat-config .setting { - padding: 5px; - border-radius: 5px; - cursor: pointer; -} -#app .chat-page .el-main .chat-container .chat-config .setting .iconfont { - font-size: 18px; - color: #19c37d; -} -#app .chat-page .el-main .chat-container .chat-config .setting:hover { - background: #d5fad3; -} -#app .chat-page .el-main .chat-container .chat-config .el-button .el-icon { - margin-right: 5px; -} -#app .chat-page .el-main .chat-container #container { - overflow: hidden; - width: 100%; - position: relative; -} -#app .chat-page .el-main .chat-container #container ::-webkit-scrollbar { - width: 12px /* 滚动条宽度 */; - background: #f1f1f1; -} -#app .chat-page .el-main .chat-container #container ::-webkit-scrollbar-track { - background-color: #e1e1e1; -} -#app .chat-page .el-main .chat-container #container ::-webkit-scrollbar-thumb { - background-color: #c1c1c1; - border-radius: 12px; -} -#app .chat-page .el-main .chat-container #container ::-webkit-scrollbar-thumb:hover { - background-color: #a8a8a8; -} -#app .chat-page .el-main .chat-container #container .chat-box { - overflow-y: auto; - --content-font-size: 16px; - --content-color: #c1c1c1; - font-family: 'Microsoft YaHei', '微软雅黑', Arial, sans-serif; - padding: 0 0 50px 0; -} -#app .chat-page .el-main .chat-container #container .chat-box .chat-line { - font-size: 14px; - display: flex; - align-items: flex-start; -} -#app .chat-page .el-main .chat-container #container .input-box { - position: absolute; - bottom: 0; - width: 100%; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner { - display: flex; - background-color: #fff; - justify-content: center; - align-items: center; - box-shadow: 0 2px 15px rgba(0,0,0,0.1); - padding: 0 15px; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .tool-item { - margin-right: 15px; - border-radius: 6px; - color: #19c37d; - display: flex; - justify-content: center; - justify-items: center; - padding: 6px; - cursor: pointer; - background: #f2f2f2; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .tool-item:hover { - background: #d5fad3; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .tool-item .iconfont { - font-size: 24px; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body { - width: 100%; - margin: 0; - border: none; - padding: 10px 0; - display: flex; - justify-content: center; - position: relative; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .hide-div { - white-space: pre-wrap; /* 保持文本换行 */ - visibility: hidden; /* 隐藏 div */ - position: absolute; /* 脱离文档流 */ - line-height: 24px; - font-size: 14px; - word-wrap: break-word; /* 允许单词换行 */ - overflow-wrap: break-word; /* 允许长单词换行,适用于现代浏览器 */ -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border { - display: flex; - width: 100%; - overflow: hidden; - border: 2px solid #21aa93; - border-radius: 10px; - padding: 10px; - background-color: #f4f4f4; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .input-inner { - display: flex; - flex-flow: column; - width: 100%; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .input-inner .file-list { - padding-bottom: 10px; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .input-inner .prompt-input::-webkit-scrollbar { - width: 0; - height: 0; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .input-inner .prompt-input { - width: 100%; - line-height: 24px; - border: none; - font-size: 14px; - background: none; - resize: none; - white-space: pre-wrap; /* 保持文本换行 */ - word-wrap: break-word; /* 允许单词换行 */ - overflow-wrap: break-word; /* 允许长单词换行,适用于现代浏览器 */ -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .send-btn { - width: 32px; - margin-left: 10px; -} -#app .chat-page .el-main .chat-container #container .input-box .input-box-inner .input-body .input-border .send-btn .el-button { - padding: 8px 5px; - border-radius: 6px; - font-size: 20px; -} -#app .chat-page .el-main .chat-container #container::-webkit-scrollbar { - width: 0; - height: 0; -} -#app .el-message-box { - width: 90%; - max-width: 420px; -} -#app .el-message { - min-width: 100px; - max-width: 600px; -} -.el-select-dropdown__wrap .el-select-dropdown__item .role-option { - display: flex; - flex-flow: row; - margin-top: 8px; -} -.el-select-dropdown__wrap .el-select-dropdown__item .role-option .el-image { - width: 20px; - height: 20px; - border-radius: 50%; -} -.el-select-dropdown__wrap .el-select-dropdown__item .role-option span { - margin-left: 5px; - height: 20px; - line-height: 20px; -} -.account { - display: flex; - background-color: #90ffc2; - color: #000; - width: 100%; - border-radius: 10px; - padding: 10px; -} -.account .vip-logo .el-image { - width: 40px; - height: 40px; - border-radius: 100%; - background-color: #fff; -} -.account .vip-info { - padding: 0 10px 0 10px; -} -.account .vip-info h4, -.account .vip-info p { - margin: 0; -} -.account .vip-info h4 { - font-weight: bold; - font-size: 16px; -} -.account .vip-info p { - color: #333; -} -.account .pay-btn { - width: 100%; - display: flex; - justify-content: right; - align-items: center; -} -.el-overlay-dialog .el-dialog .el-dialog__body .notice { - line-height: 1.8; - font-size: 16px; - overflow: auto; - height: 100%; -} -.dialog-service { - text-align: center; -} -.dialog-service .el-image { - width: 360px; -} diff --git a/web/src/assets/css/chat-plus.styl b/web/src/assets/css/chat-plus.styl index b1dbce43..a2d73018 100644 --- a/web/src/assets/css/chat-plus.styl +++ b/web/src/assets/css/chat-plus.styl @@ -160,13 +160,15 @@ $borderColor = #4676d0; padding 5px border-radius 5px cursor pointer + background-color #f2f2f2 + margin-right 10px .iconfont { font-size 18px color #19c37d } &:hover { - background #D5FAD3 + background-color #D5FAD3 } } @@ -427,4 +429,11 @@ $borderColor = #4676d0; .el-image { width 360px; } +} + +.tools-dropdown { + width auto + .el-icon { + margin-left 5px; + } } \ No newline at end of file diff --git a/web/src/assets/css/home.css b/web/src/assets/css/home.css deleted file mode 100644 index f037af08..00000000 --- a/web/src/assets/css/home.css +++ /dev/null @@ -1,154 +0,0 @@ -.home { - display: flex; - height: 100vh; - width: 100%; - flex-flow: column; -} -.home .header { - display: flex; - justify-content: space-between; - height: 50px; - line-height: 50px; - background-color: #1e1f22; - padding-right: 20px; -} -.home .header .banner { - display: flex; -} -.home .header .banner .logo { - display: flex; - padding: 5px; - cursor: pointer; -} -.home .header .banner .logo .el-image { - width: 48px; - height: 48px; - background-color: #fff; - border-radius: 50%; -} -.home .header .banner .title { - display: flex; - color: #fff; - font-size: 20px; - padding: 0 10px; -} -.home .header .navbar { - display: flex; - flex-flow: row; -} -.home .header .navbar .link-button { - margin-right: 15px; - color: #e1e1e1; - padding: 0 10px; -} -.home .header .navbar .link-button:hover { - background-color: #414141; -} -.home .header .navbar .link-button .iconfont { - font-size: 24px; -} -.home .header .navbar .user-info { - width: 100%; - padding: 5px 0; -} -.home .header .navbar .user-info .el-dropdown-link { - width: 100%; - cursor: pointer; - display: flex; -} -.home .header .navbar .user-info .el-dropdown-link .el-image { - width: 36px; - height: 36px; - border-radius: 50%; -} -.home .header .navbar .user-info .el-dropdown-link .el-icon { - color: #ccc; - line-height: 24px; -} -.home .main { - width: 100%; - display: flex; - flex-flow: row; -} -.home .main .navigator { - display: flex; - flex-flow: column; - width: 60px; - padding: 10px 1px; - border-right: 1px solid #3c3c3c; - background-color: #1e1f22; -} -.home .main .navigator .nav-items { - margin-top: 10px; - padding: 0 5px; -} -.home .main .navigator .nav-items li { - margin-bottom: 15px; - display: flex; - flex-flow: column; -} -.home .main .navigator .nav-items li a { - color: #dadbdc; - border-radius: 10px; - width: 48px; - height: 48px; - display: flex; - justify-content: center; - align-items: center; - cursor: pointer; - background-color: #414348; -} -.home .main .navigator .nav-items li a .el-image { - border-radius: 10px; -} -.home .main .navigator .nav-items li a .iconfont { - font-size: 20px; -} -.home .main .navigator .nav-items li a:hover, -.home .main .navigator .nav-items li a.active { - color: #47fff1; - background-color: #0f7a71; -} -.home .main .navigator .nav-items li .title { - font-size: 12px; - padding-top: 6px; - color: #e5e7eb; - text-align: center; - white-space: nowrap; /* 防止文本换行 */ - overflow: hidden; /* 隐藏溢出内容 */ - text-overflow: unset; /* 使用省略号表示溢出内容 */ -} -.home .main .navigator .nav-items li .active { - color: #47fff1; -} -.home .main .content { - width: 100%; - overflow: auto; - box-sizing: border-box; - background-color: #282c34; -} -.el-popper .more-menus li { - padding: 10px 15px; - cursor: pointer; - border-radius: 5px; - margin: 5px 0; -} -.el-popper .more-menus li .el-image { - position: relative; - top: 5px; - right: 5px; -} -.el-popper .more-menus li:hover { - background-color: #f1f1f1; -} -.el-popper .more-menus li.active { - background-color: #f1f1f1; -} -.el-popper .user-info-menu li a { - width: 100%; - justify-content: left; -} -.el-popper .user-info-menu li a:hover { - text-decoration: none !important; - color: var(--el-primary-text-color); -} diff --git a/web/src/assets/css/home.styl b/web/src/assets/css/home.styl index 41bd1257..b435bdb9 100644 --- a/web/src/assets/css/home.styl +++ b/web/src/assets/css/home.styl @@ -23,7 +23,6 @@ .el-image { width 48px height 48px - background-color #ffffff border-radius 50% } } diff --git a/web/src/assets/css/index.css b/web/src/assets/css/index.css deleted file mode 100644 index 3eaa8878..00000000 --- a/web/src/assets/css/index.css +++ /dev/null @@ -1,101 +0,0 @@ -.index-page { - margin: 0; - overflow: hidden; - color: #fff; - display: flex; - justify-content: center; - align-items: baseline; - padding-top: 150px; -} -.index-page .color-bg { - position: absolute; - top: 0; - left: 0; - width: 100vw; - height: 100vh; -} -.index-page .image-bg { - filter: blur(8px); - background-size: cover; - background-position: center; -} -.index-page .shadow { - box-shadow: rgba(0,0,0,0.3) 0px 0px 3px; -} -.index-page .shadow:hover { - box-shadow: rgba(0,0,0,0.3) 0px 0px 8px; -} -.index-page .menu-box { - position: absolute; - top: 0; - width: 100%; - display: flex; -} -.index-page .menu-box .el-menu { - padding: 0 30px; - width: 100%; - display: flex; - justify-content: space-between; - background: none; - border: none; -} -.index-page .menu-box .el-menu .menu-item { - display: flex; - padding: 20px 0; - color: #fff; -} -.index-page .menu-box .el-menu .menu-item .title { - font-size: 24px; - padding: 10px 10px 0 10px; -} -.index-page .menu-box .el-menu .menu-item .el-image { - height: 50px; - background-color: #fff; -} -.index-page .menu-box .el-menu .menu-item .el-button { - margin-left: 10px; -} -.index-page .menu-box .el-menu .menu-item .el-button span { - margin-left: 5px; -} -.index-page .content { - text-align: center; - position: relative; - display: flex; - flex-flow: column; - align-items: center; -} -.index-page .content h1 { - font-size: 5rem; - margin-bottom: 1rem; -} -.index-page .content p { - font-size: 1.5rem; - margin-bottom: 2rem; -} -.index-page .content .navs { - display: flex; - max-width: 900px; - padding: 20px; -} -.index-page .content .navs .el-space--horizontal { - justify-content: center; -} -.index-page .content .navs .nav-item { - width: 200px; -} -.index-page .content .navs .nav-item .el-button { - width: 100%; - padding: 25px 20px; - font-size: 1.3rem; - transition: all 0.3s ease; -} -.index-page .content .navs .nav-item .el-button .iconfont { - font-size: 24px; - margin-right: 10px; - position: relative; - top: -2px; -} -.index-page .footer .el-link__inner { - color: #fff; -} diff --git a/web/src/assets/css/index.styl b/web/src/assets/css/index.styl index f2e8aa0c..fbe52f78 100644 --- a/web/src/assets/css/index.styl +++ b/web/src/assets/css/index.styl @@ -54,9 +54,9 @@ padding 10px 10px 0 10px } - .el-image { + .logo { height 50px - background-color #ffffff + border-radius 50% } .el-button { diff --git a/web/src/assets/css/login.css b/web/src/assets/css/login.css deleted file mode 100644 index 0b015d74..00000000 --- a/web/src/assets/css/login.css +++ /dev/null @@ -1,99 +0,0 @@ -.bg { - position: fixed; - left: 0; - right: 0; - top: 0; - bottom: 0; - background-color: #313237; - background-image: url("~@/assets/img/login-bg.jpg"); - background-size: cover; - background-position: center; - background-repeat: repeat-y; -} -.main .contain { - position: fixed; - left: 50%; - top: 40%; - width: 90%; - max-width: 400px; - transform: translate(-50%, -50%); - padding: 20px 10px; - color: #fff; - border-radius: 10px; -} -.main .contain .logo { - text-align: center; -} -.main .contain .logo .el-image { - width: 120px; - cursor: pointer; - background-color: #fff; - border-radius: 50%; -} -.main .contain .header { - width: 100%; - margin-bottom: 24px; - font-size: 24px; - color: $white_v1; - letter-space: 2px; - text-align: center; - padding-top: 10px; -} -.main .contain .content { - width: 100%; - height: auto; - border-radius: 3px; -} -.main .contain .content .block { - margin-bottom: 16px; -} -.main .contain .content .block .el-input__inner { - border: 1px solid $gray-v6 !important; -} -.main .contain .content .block .el-input__inner .el-icon-user, -.main .contain .content .block .el-input__inner .el-icon-lock { - font-size: 20px; -} -.main .contain .content .btn-row { - padding-top: 10px; -} -.main .contain .content .btn-row .login-btn { - width: 100%; - font-size: 16px; - letter-spacing: 2px; -} -.main .contain .content .text-line { - justify-content: center; - padding-top: 10px; - font-size: 14px; -} -.main .contain .content .opt { - padding: 15px; -} -.main .contain .content .opt .el-col { - text-align: center; -} -.main .contain .content .divider { - border-top: 2px solid #c1c1c1; -} -.main .contain .content .clogin { - padding: 15px; - display: flex; - justify-content: center; -} -.main .contain .content .clogin .iconfont { - font-size: 20px; - background: #e9f1f6; - padding: 8px; - border-radius: 50%; - cursor: pointer; -} -.main .contain .content .clogin .iconfont.icon-wechat { - color: #0bc15f; -} -.main .footer { - color: #fff; -} -.main .footer .container { - padding: 20px; -} diff --git a/web/src/assets/css/login.styl b/web/src/assets/css/login.styl index 32d79ed0..9dfa7449 100644 --- a/web/src/assets/css/login.styl +++ b/web/src/assets/css/login.styl @@ -30,7 +30,6 @@ .el-image { width 120px; cursor pointer - background-color #ffffff border-radius 50% } } diff --git a/web/src/components/BindEmail.vue b/web/src/components/BindEmail.vue index 1ae9e773..719f9c70 100644 --- a/web/src/components/BindEmail.vue +++ b/web/src/components/BindEmail.vue @@ -41,7 +41,7 @@ import {computed, ref, watch} from "vue"; import SendMsg from "@/components/SendMsg.vue"; import {ElMessage} from "element-plus"; import {httpPost} from "@/utils/http"; -import {checkSession, removeUserInfo} from "@/store/cache"; +import {checkSession} from "@/store/cache"; const props = defineProps({ show: Boolean, @@ -76,7 +76,6 @@ const save = () => { } httpPost('/api/user/bind/email', form.value).then(() => { - removeUserInfo() ElMessage.success("绑定成功") emits('hide') }).catch(e => { diff --git a/web/src/components/BindMobile.vue b/web/src/components/BindMobile.vue index 41eceb3f..3f6e15e1 100644 --- a/web/src/components/BindMobile.vue +++ b/web/src/components/BindMobile.vue @@ -41,7 +41,7 @@ import {computed, ref, watch} from "vue"; import SendMsg from "@/components/SendMsg.vue"; import {ElMessage} from "element-plus"; import {httpPost} from "@/utils/http"; -import {checkSession, removeUserInfo} from "@/store/cache"; +import {checkSession} from "@/store/cache"; const props = defineProps({ show: Boolean, @@ -79,7 +79,6 @@ const save = () => { } httpPost('/api/user/bind/mobile', form.value).then(() => { - removeUserInfo() ElMessage.success("绑定成功") emits('hide') }).catch(e => { diff --git a/web/src/router.js b/web/src/router.js index 7d2a2908..5d0ff0ac 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -26,6 +26,12 @@ const routes = [ meta: {title: '创作中心'}, component: () => import('@/views/ChatPlus.vue'), }, + { + name: 'chat-id', + path: '/chat/:id', + meta: {title: '创作中心'}, + component: () => import('@/views/ChatPlus.vue'), + }, { name: 'image-mj', path: '/mj', diff --git a/web/src/store/cache.js b/web/src/store/cache.js index 46206354..368a770e 100644 --- a/web/src/store/cache.js +++ b/web/src/store/cache.js @@ -6,29 +6,15 @@ const adminDataKey = "ADMIN_INFO_CACHE_KEY" const systemInfoKey = "SYSTEM_INFO_CACHE_KEY" const licenseInfoKey = "LICENSE_INFO_CACHE_KEY" export function checkSession() { - const item = Storage.get(userDataKey) ?? {expire:0, data:null} - if (item.expire > Date.now()) { - return Promise.resolve(item.data) - } - return new Promise((resolve, reject) => { httpGet('/api/user/session').then(res => { - item.data = res.data - // cache expires after 10 secs - item.expire = Date.now() + 1000 * 30 - Storage.set(userDataKey, item) - resolve(item.data) + resolve(res.data) }).catch(e => { Storage.remove(userDataKey) reject(e) }) }) } - -export function removeUserInfo() { - Storage.remove(userDataKey) -} - export function checkAdminSession() { const item = Storage.get(adminDataKey) ?? {expire:0, data:null} if (item.expire > Date.now()) { @@ -63,7 +49,7 @@ export function getSystemInfo() { Storage.set(systemInfoKey, item) resolve(item.data) }).catch(err => { - resolve(err) + reject(err) }) }) } diff --git a/web/src/store/session.js b/web/src/store/session.js index 3f023555..4d4f9d84 100644 --- a/web/src/store/session.js +++ b/web/src/store/session.js @@ -1,6 +1,6 @@ import {randString} from "@/utils/libs"; import Storage from "good-storage"; -import {checkAdminSession, checkSession, removeAdminInfo, removeUserInfo} from "@/store/cache"; +import {removeAdminInfo} from "@/store/cache"; /** * storage handler @@ -24,7 +24,6 @@ export function setUserToken(token) { export function removeUserToken() { Storage.remove(UserTokenKey) - removeUserInfo() } export function getAdminToken() { diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 1bb04204..4550c0fa 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -3,7 +3,7 @@
- + @@ -23,7 +23,7 @@
-
@@ -100,11 +100,25 @@ + + + + + - - - - + +
@@ -202,7 +216,7 @@ import {nextTick, onMounted, onUnmounted, ref, watch} from 'vue' import ChatPrompt from "@/components/ChatPrompt.vue"; import ChatReply from "@/components/ChatReply.vue"; -import {Delete, Edit, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue' +import {Delete, Edit, InfoFilled, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue' import 'highlight.js/styles/a11y-dark.css' import { isMobile, @@ -211,7 +225,7 @@ import { UUID } from "@/utils/libs"; import {ElMessage, ElMessageBox} from "element-plus"; -import {getSessionId, getUserToken, removeUserToken} from "@/store/session"; +import {getSessionId, getUserToken} from "@/store/session"; import {httpGet, httpPost} from "@/utils/http"; import {useRouter} from "vue-router"; import Clipboard from "clipboard"; @@ -230,15 +244,15 @@ const modelID = ref(0) const chatData = ref([]); const allChats = ref([]); // 会话列表 const chatList = ref(allChats.value); -const activeChat = ref({}); const mainWinHeight = ref(0); // 主窗口高度 const chatBoxHeight = ref(0); // 聊天内容框高度 const leftBoxHeight = ref(0); -const loading = ref(true); +const loading = ref(false); const loginUser = ref(null); const roles = ref([]); const router = useRouter(); const roleId = ref(0) +const chatId = ref(); const newChatItem = ref(null); const isLogin = ref(false) const showHello = ref(true) @@ -254,7 +268,15 @@ const listStyle = ref(store.chatListStyle) watch(() => store.chatListStyle, (newValue) => { listStyle.value = newValue }); +const tools = ref([]) +const toolSelected = ref([]) +const loadHistory = ref(false) +// 初始化 ChatID +chatId.value = router.currentRoute.value.params.id +if (!chatId.value) { + chatId.value = UUID() +} if (isMobile()) { router.replace("/mobile/chat") @@ -290,6 +312,13 @@ httpGet("/api/config/get?key=notice").then(res => { ElMessage.error("获取系统配置失败:" + e.message) }) +// 获取工具函数 +httpGet("/api/function/list").then(res => { + tools.value = res.data +}).catch(e => { + showMessageError("获取工具函数失败:" + e.message) +}) + onMounted(() => { resizeElement(); initData() @@ -351,7 +380,6 @@ const initData = () => { ElMessage.error("加载会话列表失败!") }) }).catch(() => { - loading.value = false // 加载模型 httpGet('/api/model/list',{id:roleId.value}).then(res => { models.value = res.data @@ -418,6 +446,7 @@ const resizeElement = function () { const _newChat = () => { if (isLogin.value) { + chatId.value = UUID() newChat() } } @@ -428,6 +457,7 @@ const newChat = () => { store.setShowLoginDialog(true) return; } + const role = getRoleById(roleId.value) showHello.value = role.key === 'gpt'; // if the role bind a model, disable model change @@ -457,9 +487,19 @@ const newChat = () => { edit: false, removing: false, }; - activeChat.value = {} //取消激活的会话高亮 showStopGenerate.value = false; - connect(null, roleId.value) + router.push(`/chat/${chatId.value}`) + loadHistory.value = true + connect() +} + +// 切换工具 +const changeTool = () => { + if (!isLogin.value) { + return; + } + loadHistory.value = false + socket.value.close() } @@ -470,16 +510,18 @@ const loadChat = function (chat) { return; } - if (activeChat.value['chat_id'] === chat.chat_id) { + if (chatId.value === chat.chat_id) { return; } - activeChat.value = chat newChatItem.value = null; roleId.value = chat.role_id; modelID.value = chat.model_id; + chatId.value = chat.chat_id; showStopGenerate.value = false; - connect(chat.chat_id, chat.role_id) + router.push(`/chat/${chatId.value}`) + loadHistory.value = true + socket.value.close() } // 编辑会话标题 @@ -487,7 +529,6 @@ const tmpChatTitle = ref(''); const editChatTitle = (chat) => { chat.edit = true; tmpChatTitle.value = chat.title; - console.log(chat.chat_id) nextTick(() => { document.getElementById('chat-' + chat.chat_id).focus() }) @@ -542,7 +583,7 @@ const removeChat = function (chat) { return e1.id === e2.id }) // 重置会话 - newChat(); + _newChat(); }).catch(e => { ElMessage.error("操作失败:" + e.message); }) @@ -557,23 +598,10 @@ const prompt = ref(''); const showStopGenerate = ref(false); // 停止生成 const lineBuffer = ref(''); // 输出缓冲行 const socket = ref(null); -const activelyClose = ref(false); // 主动关闭 const canSend = ref(true); -const heartbeatHandle = ref(null) const sessionId = ref("") -const connect = function (chat_id, role_id) { - let isNewChat = false; - if (!chat_id) { - isNewChat = true; - chat_id = UUID(); - } - // 先关闭已有连接 - if (socket.value !== null) { - activelyClose.value = true; - socket.value.close(); - } - - const _role = getRoleById(role_id); +const connect = function () { + const chatRole = getRoleById(roleId.value); // 初始化 WebSocket 对象 sessionId.value = getSessionId(); let host = process.env.VUE_APP_WS_HOST @@ -585,26 +613,15 @@ const connect = function (chat_id, role_id) { } } - const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${role_id}&chat_id=${chat_id}&model_id=${modelID.value}&token=${getUserToken()}`); + loading.value = true + const toolIds = toolSelected.value.join(',') + const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}&tools=${toolIds}`); _socket.addEventListener('open', () => { - chatData.value = []; // 初始化聊天数据 enableInput() - activelyClose.value = false; - - if (isNewChat) { // 加载打招呼信息 - loading.value = false; - chatData.value.push({ - chat_id: chat_id, - role_id: role_id, - type: "reply", - id: randString(32), - icon: _role['icon'], - content: _role['hello_msg'], - }) - ElMessage.success({message: "对话连接成功!", duration: 1000}) - } else { // 加载聊天记录 - loadChatHistory(chat_id); + if (loadHistory.value) { + loadChatHistory(chatId.value) } + loading.value = false }); _socket.addEventListener('message', event => { @@ -619,17 +636,16 @@ const connect = function (chat_id, role_id) { chatData.value.push({ type: "reply", id: randString(32), - icon: _role['icon'], + icon: chatRole['icon'], prompt:prePrompt, content: "", }); } else if (data.type === 'end') { // 消息接收完毕 // 追加当前会话到会话列表 - if (isNewChat && newChatItem.value !== null) { + if (newChatItem.value !== null) { newChatItem.value['title'] = tmpChatTitle.value; - newChatItem.value['chat_id'] = chat_id; + newChatItem.value['chat_id'] = chatId.value; chatList.value.unshift(newChatItem.value); - activeChat.value = newChatItem.value; newChatItem.value = null; // 只追加一次 } @@ -641,7 +657,7 @@ const connect = function (chat_id, role_id) { httpPost("/api/chat/tokens", { text: "", model: getModelValue(modelID.value), - chat_id: chat_id + chat_id: chatId.value, }).then(res => { reply['created_at'] = new Date().getTime(); reply['tokens'] = res.data; @@ -662,7 +678,7 @@ const connect = function (chat_id, role_id) { // 将聊天框的滚动条滑动到最底部 nextTick(() => { document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight) - localStorage.setItem("chat_id", chat_id) + localStorage.setItem("chat_id", chatId.value) }) }; } @@ -673,18 +689,8 @@ const connect = function (chat_id, role_id) { }); _socket.addEventListener('close', () => { - if (activelyClose.value || socket.value === null) { // 忽略主动关闭 - return; - } - // 停止发送消息 disableInput(true) - loading.value = true; - checkSession().then(() => { - connect(chat_id, role_id) - }).catch(() => { - loading.value = true - showMessageError("会话已断开,刷新页面...") - }); + connect() }); socket.value = _socket; @@ -801,21 +807,20 @@ const clearAllChats = function () { }) } -const logout = function () { - activelyClose.value = true; - httpGet('/api/user/logout').then(() => { - removeUserToken() - router.push("/login") - }).catch(() => { - ElMessage.error('注销失败!'); - }) -} - const loadChatHistory = function (chatId) { + chatData.value = [] httpGet('/api/chat/history?chat_id=' + chatId).then(res => { const data = res.data - if (!data) { - loading.value = false + if (!data || data.length === 0) { // 加载打招呼信息 + const _role = getRoleById(roleId.value) + chatData.value.push({ + chat_id: chatId, + role_id: roleId.value, + type: "reply", + id: randString(32), + icon: _role['icon'], + content: _role['hello_msg'], + }) return } showHello.value = false @@ -829,7 +834,6 @@ const loadChatHistory = function (chatId) { nextTick(() => { document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight) }) - loading.value = false }).catch(e => { // TODO: 显示重新加载按钮 ElMessage.error('加载聊天记录失败:' + e.message); @@ -882,7 +886,6 @@ const shareChat = (chat) => { } const url = location.protocol + '//' + location.host + '/chat/export?chat_id=' + chat.chat_id - // console.log(url) window.open(url, '_blank'); } diff --git a/web/src/views/Dalle.vue b/web/src/views/Dalle.vue index f5c2998b..4d37c82b 100644 --- a/web/src/views/Dalle.vue +++ b/web/src/views/Dalle.vue @@ -206,7 +206,7 @@ import {nextTick, onMounted, onUnmounted, ref} from "vue" import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue"; import {httpGet, httpPost} from "@/utils/http"; -import {ElMessage, ElMessageBox, ElNotification} from "element-plus"; +import {ElMessage, ElMessageBox} from "element-plus"; import Clipboard from "clipboard"; import {checkSession, getSystemInfo} from "@/store/cache"; import {useSharedStore} from "@/store/sharedata"; @@ -338,7 +338,7 @@ const fetchRunningJobs = () => { } // 获取运行中的任务 httpGet(`/api/dall/jobs?finish=false`).then(res => { - runningJobs.value = res.data + runningJobs.value = res.data.items }).catch(e => { ElMessage.error("获取任务失败:" + e.message) }) @@ -356,10 +356,10 @@ const fetchFinishJobs = () => { page.value = page.value + 1 httpGet(`/api/dall/jobs?finish=true&page=${page.value}&page_size=${pageSize.value}`).then(res => { - if (res.data.length < pageSize.value) { + if (res.data.items.length < pageSize.value) { isOver.value = true } - const imageList = res.data + const imageList = res.data.items for (let i = 0; i < imageList.length; i++) { imageList[i]["img_thumb"] = imageList[i]["img_url"] + "?imageView2/4/w/300/h/0/q/75" } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 975266ec..c8a1a7cb 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -816,7 +816,7 @@ const fetchRunningJobs = () => { } httpGet(`/api/mj/jobs?finish=false`).then(res => { - const jobs = res.data + const jobs = res.data.items const _jobs = [] for (let i = 0; i < jobs.length; i++) { if (jobs[i].progress === 101) { @@ -853,7 +853,7 @@ const fetchFinishJobs = () => { page.value = page.value + 1 // 获取已完成的任务 httpGet(`/api/mj/jobs?finish=true&page=${page.value}&page_size=${pageSize.value}`).then(res => { - const jobs = res.data + const jobs = res.data.items for (let i = 0; i < jobs.length; i++) { if (jobs[i]['img_url'] !== "") { if (jobs[i].type === 'upscale' || jobs[i].type === 'swapFace') { diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index ab4a80aa..2b377022 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -549,7 +549,6 @@ const sdPower = ref(0) // 画一张 SD 图片消耗算力 const socket = ref(null) const userId = ref(0) -const heartbeatHandle = ref(null) const connect = () => { let host = process.env.VUE_APP_WS_HOST if (host === '') { @@ -637,7 +636,7 @@ const fetchRunningJobs = () => { // 获取运行中的任务 httpGet(`/api/sd/jobs?finish=0`).then(res => { - runningJobs.value = res.data + runningJobs.value = res.data.items }).catch(e => { ElMessage.error("获取任务失败:" + e.message) }) @@ -655,10 +654,10 @@ const fetchFinishJobs = () => { page.value = page.value + 1 httpGet(`/api/sd/jobs?finish=1&page=${page.value}&page_size=${pageSize.value}`).then(res => { - if (res.data.length < pageSize.value) { + if (res.data.items.length < pageSize.value) { isOver.value = true } - const imageList = res.data + const imageList = res.data.items for (let i = 0; i < imageList.length; i++) { imageList[i]["img_thumb"] = imageList[i]["img_url"] + "?imageView2/4/w/300/h/0/q/75" } diff --git a/web/src/views/ImagesWall.vue b/web/src/views/ImagesWall.vue index fa8c4d83..8f3b7533 100644 --- a/web/src/views/ImagesWall.vue +++ b/web/src/views/ImagesWall.vue @@ -355,13 +355,13 @@ const getNext = () => { } httpGet(`${url}?page=${page.value}&page_size=${pageSize.value}`).then(res => { loading.value = false - if (!res.data || res.data.length === 0) { + if (!res.data.items || res.data.items.length === 0) { isOver.value = true return } // 生成缩略图 - const imageList = res.data + const imageList = res.data.items for (let i = 0; i < imageList.length; i++) { imageList[i]["img_thumb"] = imageList[i]["img_url"] + "?imageView2/4/w/300/h/0/q/75" } diff --git a/web/src/views/Index.vue b/web/src/views/Index.vue index fdbd6e93..31573579 100644 --- a/web/src/views/Index.vue +++ b/web/src/views/Index.vue @@ -7,7 +7,7 @@ :ellipsis="false" >