diff --git a/CHANGELOG.md b/CHANGELOG.md index a190b0e3..bfd50612 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ * Bug修复:手机端角色和模型选择不生效 * Bug修复:用户登录过期之后聊天页面出现大量报错,需要刷新页面才能正常 * 功能优化:优化聊天页面 Websocket 断线重连代码,提高用户体验 +* 功能优化:给算力增减服务全部加上数据库事务和同步锁 * 功能新增:支持 Luma 文生视频功能 ## v4.1.2 diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index ef45883d..164c7fd9 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -46,9 +46,10 @@ type ChatHandler struct { licenseService *service.LicenseService ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message + userService *service.UserService } -func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { +func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler { return &ChatHandler{ BaseHandler: handler.BaseHandler{App: app, DB: db}, redis: redis, @@ -56,6 +57,7 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag licenseService: licenseService, ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ChatContexts: types.NewLMap[string, []types.Message](), + userService: userService, } } @@ -482,24 +484,15 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p if session.Model.Power > 0 { power = session.Model.Power } - res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power)) - if res.Error == nil { - // 记录算力消费日志 - var u model.User - h.DB.Where("id", userVo.Id).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: userVo.Id, - Username: userVo.Username, - Type: types.PowerConsume, - Amount: power, - Mark: types.PowerSub, - Balance: u.Power, - Model: session.Model.Value, - Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), - CreatedAt: time.Now(), - }) - } + err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{ + Type: types.PowerConsume, + Model: session.Model.Value, + Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), + }) + if err != nil { + logger.Error(err) + } } func (h *ChatHandler) saveChatHistory( diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index d09f0651..c1e9b664 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -11,32 +11,33 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/dalle" "geekai/service/oss" "geekai/store/model" "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "github.com/gorilla/websocket" - "net/http" - "time" - "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" "gorm.io/gorm" + "net/http" ) type DallJobHandler struct { BaseHandler - redis *redis.Client - service *dalle.Service - uploader *oss.UploaderManager + redis *redis.Client + dallService *dalle.Service + uploader *oss.UploaderManager + userService *service.UserService } -func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { +func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler { return &DallJobHandler{ - service: service, - uploader: manager, + dallService: service, + uploader: manager, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -61,14 +62,14 @@ func (h *DallJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.dallService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) go func() { for { _, msg, err := client.Receive() if err != nil { client.Close() - h.service.Clients.Delete(uint(userId)) + h.dallService.Clients.Delete(uint(userId)) return } @@ -127,7 +128,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.DallTask{ + h.dallService.PushTask(types.DallTask{ JobId: job.Id, UserId: uint(userId), Prompt: data.Prompt, @@ -137,7 +138,7 @@ func (h *DallJobHandler) Image(c *gin.Context) { Power: job.Power, }) - client := h.service.Clients.Get(job.UserId) + client := h.dallService.Clients.Get(job.UserId) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -233,26 +234,11 @@ func (h *DallJobHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "dall-e-3", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index 8196a81e..b4147deb 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -15,6 +15,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/store/model" "geekai/utils" "github.com/gin-gonic/gin" @@ -30,13 +31,15 @@ import ( // MarkMapHandler 生成思维导图 type MarkMapHandler struct { BaseHandler - clients *types.LMap[int, *types.WsClient] + clients *types.LMap[int, *types.WsClient] + userService *service.UserService } -func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { +func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler { return &MarkMapHandler{ BaseHandler: BaseHandler{App: app, DB: db}, clients: types.NewLMap[int, *types.WsClient](), + userService: userService, } } @@ -185,22 +188,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode // 扣减算力 if chatModel.Power > 0 { - res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) - if res.Error == nil { - // 记录算力消费日志 - var u model.User - h.DB.Where("id", userId).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: u.Id, - Username: u.Username, - Type: types.PowerConsume, - Amount: chatModel.Power, - Mark: types.PowerSub, - Balance: u.Power, - Model: chatModel.Value, - Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: chatModel.Value, + Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), + }) + if err != nil { + return err } } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 212729b2..342b34fc 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -30,16 +30,18 @@ import ( type MidJourneyHandler struct { BaseHandler - service *mj.Service - snowflake *service.Snowflake - uploader *oss.UploaderManager + mjService *mj.Service + snowflake *service.Snowflake + uploader *oss.UploaderManager + userService *service.UserService } -func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler { +func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler { return &MidJourneyHandler{ - snowflake: snowflake, - service: service, - uploader: manager, + snowflake: snowflake, + mjService: service, + uploader: manager, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -80,7 +82,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.mjService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -196,7 +198,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, TaskId: taskId, Type: types.TaskType(data.TaskType), @@ -208,28 +210,22 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -269,7 +265,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskUpscale, UserId: userId, @@ -280,27 +276,22 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } + // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -334,7 +325,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - h.service.PushTask(types.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskVariation, UserId: userId, @@ -345,28 +336,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { Mode: h.App.SysConfig.MjMode, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } - // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "mid-journey", - Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "mid-journey", + Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } + resp.SUCCESS(c) } @@ -465,25 +449,11 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "mid-journey", - Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s,Err: %s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "mid-journey", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) @@ -498,7 +468,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } - client := h.service.Clients.Get(uint(job.UserId)) + client := h.mjService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } diff --git a/api/handler/redeem_handler.go b/api/handler/redeem_handler.go index 4f557ce9..b6759ffd 100644 --- a/api/handler/redeem_handler.go +++ b/api/handler/redeem_handler.go @@ -11,6 +11,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/store/model" "geekai/utils/resp" "github.com/gin-gonic/gin" @@ -21,11 +22,12 @@ import ( type RedeemHandler struct { BaseHandler - lock sync.Mutex + lock sync.Mutex + userService *service.UserService } -func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler { - return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}} +func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler { + return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService} } func (h *RedeemHandler) Verify(c *gin.Context) { @@ -59,7 +61,11 @@ func (h *RedeemHandler) Verify(c *gin.Context) { } tx := h.DB.Begin() - err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", item.Power)).Error + err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{ + Type: types.PowerRedeem, + Model: "兑换码", + Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) @@ -76,26 +82,6 @@ func (h *RedeemHandler) Verify(c *gin.Context) { return } - // 记录算力充值日志 - var user model.User - err = tx.Where("id", userId).First(&user).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - - h.DB.Create(&model.PowerLog{ - UserId: userId, - Username: user.Username, - Type: types.PowerRedeem, - Amount: item.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "兑换码", - Remark: fmt.Sprintf("兑换码核销,算力:%d,兑换码:%s...", item.Power, item.Code[:10]), - CreatedAt: time.Now(), - }) tx.Commit() resp.SUCCESS(c) diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 9cbc60fb..3c20ef33 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -31,19 +31,27 @@ import ( type SdJobHandler struct { BaseHandler - redis *redis.Client - service *sd.Service - uploader *oss.UploaderManager - snowflake *service.Snowflake - leveldb *store.LevelDB + redis *redis.Client + sdService *sd.Service + uploader *oss.UploaderManager + snowflake *service.Snowflake + leveldb *store.LevelDB + userService *service.UserService } -func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { +func NewSdJobHandler(app *core.AppServer, + db *gorm.DB, + service *sd.Service, + manager *oss.UploaderManager, + snowflake *service.Snowflake, + userService *service.UserService, + levelDB *store.LevelDB) *SdJobHandler { return &SdJobHandler{ - service: service, - uploader: manager, - snowflake: snowflake, - leveldb: levelDB, + sdService: service, + uploader: manager, + snowflake: snowflake, + leveldb: levelDB, + userService: userService, BaseHandler: BaseHandler{ App: app, DB: db, @@ -68,7 +76,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.sdService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -159,34 +167,27 @@ func (h *SdJobHandler) Image(c *gin.Context) { return } - h.service.PushTask(types.SdTask{ + h.sdService.PushTask(types.SdTask{ Id: int(job.Id), Type: types.TaskImage, Params: params, UserId: userId, }) - client := h.service.Clients.Get(uint(job.UserId)) + client := h.sdService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } // update user's power - tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "stable-diffusion", - Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), - CreatedAt: time.Now(), - }) + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "stable-diffusion", + Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } resp.SUCCESS(c) @@ -290,25 +291,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "stable-diffusion", - Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "stable-diffusion", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 721ac4e0..624b703e 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -11,6 +11,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/oss" "geekai/service/suno" "geekai/store/model" @@ -26,18 +27,20 @@ import ( type SunoHandler struct { BaseHandler - service *suno.Service - uploader *oss.UploaderManager + sunoService *suno.Service + uploader *oss.UploaderManager + userService *service.UserService } -func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler { +func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler { return &SunoHandler{ BaseHandler: BaseHandler{ App: app, DB: db, }, - service: service, - uploader: uploader, + sunoService: service, + uploader: uploader, + userService: userService, } } @@ -58,7 +61,7 @@ func (h *SunoHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.sunoService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -123,7 +126,7 @@ func (h *SunoHandler) Create(c *gin.Context) { } // 创建任务 - h.service.PushTask(types.SunoTask{ + h.sunoService.PushTask(types.SunoTask{ Id: job.Id, UserId: job.UserId, Type: job.Type, @@ -140,24 +143,17 @@ func (h *SunoHandler) Create(c *gin.Context) { }) // update user's power - tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName), + CreatedAt: time.Now(), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } - client := h.service.Clients.Get(uint(job.UserId)) + client := h.sunoService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -239,25 +235,11 @@ func (h *SunoHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: job.ModelName, - Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: job.ModelName, + Remark: fmt.Sprintf("Suno 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 837fecc3..096da7c5 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -34,6 +34,7 @@ type UserHandler struct { redis *redis.Client licenseService *service.LicenseService captcha *service.CaptchaService + userService *service.UserService } func NewUserHandler( @@ -42,6 +43,7 @@ func NewUserHandler( searcher *xdb.Searcher, client *redis.Client, captcha *service.CaptchaService, + userService *service.UserService, licenseService *service.LicenseService) *UserHandler { return &UserHandler{ BaseHandler: BaseHandler{DB: db, App: app}, @@ -49,6 +51,7 @@ func NewUserHandler( redis: client, captcha: captcha, licenseService: licenseService, + userService: userService, } } @@ -155,10 +158,9 @@ func (h *UserHandler) Register(c *gin.Context) { user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)) } - res := h.DB.Create(&user) - if res.Error != nil { - resp.ERROR(c, "保存数据失败") - logger.Error(res.Error) + tx := h.DB.Begin() + if err := tx.Create(&user).Error; err != nil { + resp.ERROR(c, err.Error()) return } @@ -167,35 +169,35 @@ func (h *UserHandler) Register(c *gin.Context) { // 增加邀请数量 h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) if h.App.SysConfig.InvitePower > 0 { - h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) - // 记录邀请算力充值日志 - var inviter model.User - h.DB.Where("id", inviteCode.UserId).First(&inviter) - h.DB.Create(&model.PowerLog{ - UserId: inviter.Id, - Username: inviter.Username, - Type: types.PowerInvite, - Amount: h.App.SysConfig.InvitePower, - Balance: inviter.Power, - Mark: types.PowerAdd, - Model: "", - Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), - CreatedAt: time.Now(), + err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{ + Type: types.PowerInvite, + Model: "", + Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), }) + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } } // 添加邀请记录 - h.DB.Create(&model.InviteLog{ + err := tx.Create(&model.InviteLog{ InviterId: inviteCode.UserId, UserId: user.Id, Username: user.Username, InviteCode: inviteCode.Code, Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), - }) + }).Error + if err != nil { + tx.Rollback() + resp.ERROR(c, err.Error()) + return + } } + tx.Commit() _ = h.redis.Del(c, key) // 注册成功,删除短信验证码 - // 自动登录创建 token token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "user_id": user.Id, diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index d7e7b5bb..f6d75134 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -11,6 +11,7 @@ import ( "fmt" "geekai/core" "geekai/core/types" + "geekai/service" "geekai/service/oss" "geekai/service/video" "geekai/store/model" @@ -21,23 +22,24 @@ import ( "github.com/gorilla/websocket" "gorm.io/gorm" "net/http" - "time" ) type VideoHandler struct { BaseHandler - service *video.Service - uploader *oss.UploaderManager + videoService *video.Service + uploader *oss.UploaderManager + userService *service.UserService } -func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager) *VideoHandler { +func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler { return &VideoHandler{ BaseHandler: BaseHandler{ App: app, DB: db, }, - service: service, - uploader: uploader, + videoService: service, + uploader: uploader, + userService: userService, } } @@ -58,7 +60,7 @@ func (h *VideoHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.service.Clients.Put(uint(userId), client) + h.videoService.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -102,7 +104,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { } // 创建任务 - h.service.PushTask(types.VideoTask{ + h.videoService.PushTask(types.VideoTask{ Id: job.Id, UserId: userId, Type: types.VideoLuma, @@ -111,24 +113,17 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) { }) // update user's power - tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - user, _ := h.GetLoginUser(c) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: job.Power, - Balance: user.Power - job.Power, - Mark: types.PowerSub, - Model: "luma", - Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id), - CreatedAt: time.Now(), - }) + err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "luma", + Remark: fmt.Sprintf("Luma 文生视频,任务ID:%d", job.Id), + }) + if err != nil { + resp.ERROR(c, err.Error()) + return } - client := h.service.Clients.Get(uint(job.UserId)) + client := h.videoService.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -194,25 +189,11 @@ func (h *VideoHandler) Remove(c *gin.Context) { // 如果任务未完成,或者任务失败,则恢复用户算力 if job.Progress != 100 { - err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error - if err != nil { - tx.Rollback() - resp.ERROR(c, err.Error()) - return - } - var user model.User - tx.Where("id = ?", job.UserId).First(&user) - err = tx.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerRefund, - Amount: job.Power, - Balance: user.Power, - Mark: types.PowerAdd, - Model: "luma", - Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), - CreatedAt: time.Now(), - }).Error + err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{ + Type: types.PowerRefund, + Model: "luma", + Remark: fmt.Sprintf("Luma 任务失败,退回算力。任务ID:%s,Err:%s", job.TaskId, job.ErrMsg), + }) if err != nil { tx.Rollback() resp.ERROR(c, err.Error()) @@ -230,7 +211,14 @@ func (h *VideoHandler) Publish(c *gin.Context) { id := h.GetInt(c, "id", 0) userId := h.GetLoginUserId(c) publish := h.GetBool(c, "publish") - err := h.DB.Model(&model.VideoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error + var job model.VideoJob + err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + err = h.DB.Model(&job).UpdateColumn("publish", publish).Error if err != nil { resp.ERROR(c, err.Error()) return diff --git a/api/main.go b/api/main.go index e3d0ae7a..82cf02f7 100644 --- a/api/main.go +++ b/api/main.go @@ -209,7 +209,7 @@ func main() { s.CheckTaskNotify() s.DownloadFiles() }), - + fx.Provide(service.NewUserService), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), fx.Provide(payment.NewJPayService), diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 5b1d4ab8..4ea1082e 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -35,9 +35,10 @@ type Service struct { taskQueue *store.RedisQueue notifyQueue *store.RedisQueue Clients *types.LMap[uint, *types.WsClient] // UserId => Client + userService *service.UserService } -func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service { return &Service{ httpClient: req.C().SetTimeout(time.Minute * 3), db: db, @@ -45,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), Clients: types.NewLMap[uint, *types.WsClient](), uploadManager: manager, + userService: userService, } } @@ -122,32 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { return "", errors.New("insufficient of power") } - // 更新用户算力 - tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - var u model.User - s.db.Where("id", user.Id).First(&u) - s.db.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: task.Power, - Balance: u.Power, - Mark: types.PowerSub, - Model: "dall-e-3", - Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), - CreatedAt: time.Now(), - }) + // 扣减算力 + err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{ + Type: types.PowerConsume, + Model: "dall-e-3", + Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), + }) + if err != nil { + return "", fmt.Errorf("error with decrease power: %v", err) } // get image generation API KEY var apiKey model.ApiKey - tx = s.db.Where("type", "dalle"). + err = s.db.Where("type", "dalle"). Where("enabled", true). - Order("last_used_at ASC").First(&apiKey) - if tx.Error != nil { - return "", fmt.Errorf("no available DALL-E api key: %v", tx.Error) + Order("last_used_at ASC").First(&apiKey).Error + if err != nil { + return "", fmt.Errorf("no available DALL-E api key: %v", err) } var res imgRes @@ -181,13 +174,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // update the api key last use time s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // update task progress - tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ "progress": 100, "org_url": res.Data[0].Url, "prompt": prompt, - }) - if tx.Error != nil { - return "", fmt.Errorf("err with update database: %v", tx.Error) + }).Error + if err != nil { + return "", fmt.Errorf("err with update database: %v", err) } s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) diff --git a/api/service/user_service.go b/api/service/user_service.go new file mode 100644 index 00000000..ea086d01 --- /dev/null +++ b/api/service/user_service.go @@ -0,0 +1,83 @@ +package service + +import ( + "fmt" + "geekai/core/types" + "geekai/store/model" + "gorm.io/gorm" + "sync" + "time" +) + +type UserService struct { + db *gorm.DB + lock sync.Mutex +} + +func NewUserService(db *gorm.DB) *UserService { + return &UserService{db: db, lock: sync.Mutex{}} +} + +// IncreasePower 增加用户算力 +func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error { + s.lock.Lock() + defer s.lock.Unlock() + + tx := s.db.Begin() + err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error + if err != nil { + tx.Rollback() + return err + } + var user model.User + tx.Where("id", userId).First(&user) + err = tx.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: log.Type, + Amount: power, + Balance: user.Power, + Mark: types.PowerAdd, + Model: log.Model, + Remark: log.Remark, + CreatedAt: time.Now(), + }).Error + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +// DecreasePower 减少用户算力 +func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error { + s.lock.Lock() + defer s.lock.Unlock() + + tx := s.db.Begin() + err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error + if err != nil { + tx.Rollback() + return fmt.Errorf("扣减算力失败:%v", err) + } + var user model.User + tx.Where("id", userId).First(&user) + err = tx.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: log.Type, + Amount: power, + Balance: user.Power, + Mark: types.PowerSub, + Model: log.Model, + Remark: log.Remark, + CreatedAt: time.Now(), + }).Error + if err != nil { + tx.Rollback() + return fmt.Errorf("记录算力日志失败:%v", err) + } + tx.Commit() + return nil +} diff --git a/api/service/xxl_job_service.go b/api/service/xxl_job_service.go index 2adecf1b..ef701730 100644 --- a/api/service/xxl_job_service.go +++ b/api/service/xxl_job_service.go @@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms // 自动将 VIP 会员的算力补充到每月赠送的最大值 func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) { logger.Info("开始进行月底账号盘点...") - var users []model.User - res := e.db.Where("vip", 1).Where("status", 1).Find(&users) - if res.Error != nil { - return "No vip users found" - } - - var sysConfig model.Config - res = e.db.Where("marker", "system").First(&sysConfig) - if res.Error != nil { - return "error with get system config: " + res.Error.Error() - } - - var config types.SystemConfig - err := utils.JsonDecode(sysConfig.Config, &config) - if err != nil { - return "error with decode system config: " + err.Error() - } - - for _, u := range users { - // 处理过期的 VIP - if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() { - u.Vip = false - e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false) - continue - } - if u.Power < config.VipMonthPower { - power := config.VipMonthPower - u.Power - // update user - tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power)) - // 记录算力变动日志 - if tx.Error == nil { - var user model.User - e.db.Where("id", u.Id).First(&user) - e.db.Create(&model.PowerLog{ - UserId: u.Id, - Username: u.Username, - Type: types.PowerRecharge, - Amount: power, - Mark: types.PowerAdd, - Balance: user.Power, - Model: "系统盘点", - Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower), - CreatedAt: time.Now(), - }) - } - } - } - logger.Info("月底盘点完成!") return "success" } diff --git a/deploy/docker-compose.yaml b/deploy/docker-compose.yaml index 0030482a..03fa5a73 100644 --- a/deploy/docker-compose.yaml +++ b/deploy/docker-compose.yaml @@ -27,17 +27,17 @@ services: ports: - "6380:6379" - xxl-job-admin: - container_name: geekai-xxl-job-admin - image: registry.cn-shenzhen.aliyuncs.com/geekmaster/xxl-job-admin:2.4.0 - restart: always - ports: - - "8081:8080" - environment: - - PARAMS=--spring.config.location=/application.properties - volumes: - - ./logs/xxl-job:/data/applogs - - ./conf/xxl-job/application.properties:/application.properties +# xxl-job-admin: +# container_name: geekai-xxl-job-admin +# image: registry.cn-shenzhen.aliyuncs.com/geekmaster/xxl-job-admin:2.4.0 +# restart: always +# ports: +# - "8081:8080" +# environment: +# - PARAMS=--spring.config.location=/application.properties +# volumes: +# - ./logs/xxl-job:/data/applogs +# - ./conf/xxl-job/application.properties:/application.properties tika: image: registry.cn-shenzhen.aliyuncs.com/geekmaster/tika:latest @@ -46,14 +46,14 @@ services: ports: - "9998:9998" - midjourney-proxy: - image: registry.cn-shenzhen.aliyuncs.com/geekmaster/midjourney-proxy:2.6.2 - container_name: geekai-midjourney-proxy - restart: always - ports: - - "8082:8080" - volumes: - - ./conf/mj-proxy:/home/spring/config +# midjourney-proxy: +# image: registry.cn-shenzhen.aliyuncs.com/geekmaster/midjourney-proxy:2.6.2 +# container_name: geekai-midjourney-proxy +# restart: always +# ports: +# - "8082:8080" +# volumes: +# - ./conf/mj-proxy:/home/spring/config # 后端 API 程序 diff --git a/web/src/views/admin/Users.vue b/web/src/views/admin/Users.vue index 73ab35d6..18fe1f55 100644 --- a/web/src/views/admin/Users.vue +++ b/web/src/views/admin/Users.vue @@ -11,7 +11,8 @@ - + +