merge v4.1.3

This commit is contained in:
RockYang
2024-12-16 10:07:52 +08:00
134 changed files with 4804 additions and 1583 deletions

View File

@@ -201,7 +201,6 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/admin/logout" ||
c.Request.URL.Path == "/api/admin/login/captcha" ||
c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/user/session" ||
c.Request.URL.Path == "/api/chat/history" ||
c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" ||
@@ -227,6 +226,8 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/suno/client" ||
c.Request.URL.Path == "/api/suno/detail" ||
c.Request.URL.Path == "/api/suno/play" ||
c.Request.URL.Path == "/api/download" ||
c.Request.URL.Path == "/api/video/client" ||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
@@ -367,6 +368,7 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 直接输出图像数据流
c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
c.Abort() // 中断请求
}
c.Next()
}

View File

@@ -57,6 +57,7 @@ type ChatSession struct {
ClientIP string `json:"client_ip"` // 客户端 IP
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
Model ChatModel `json:"model"` // GPT 模型
Tools string `json:"tools"` // 函数
}
type ChatModel struct {

View File

@@ -131,10 +131,10 @@ func (c RedisConfig) Url() string {
}
type SystemConfig struct {
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"`
Title string `json:"title,omitempty"` // 网站标题
Slogan string `json:"slogan,omitempty"` // 网站 slogan
AdminTitle string `json:"admin_title,omitempty"` // 管理后台标题
Logo string `json:"logo,omitempty"` // 方形 Logo
InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值
DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力
InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值
@@ -150,8 +150,9 @@ type SystemConfig struct {
MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力
MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力
DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -165,4 +166,6 @@ type SystemConfig struct {
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单
Copyright string `json:"copyright"` // 版权信息
MarkMapText string `json:"mark_map_text"` // 思维导入的默认文本
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
}

View File

@@ -85,13 +85,41 @@ type SunoTask struct {
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type int `json:"type"`
TaskId string `json:"task_id"`
Title string `json:"title"`
RefTaskId string `json:"ref_task_id"`
RefSongId string `json:"ref_song_id"`
RefTaskId string `json:"ref_task_id,omitempty"`
RefSongId string `json:"ref_song_id,omitempty"`
Prompt string `json:"prompt"` // 提示词/歌词
Tags string `json:"tags"`
Model string `json:"model"`
Instrumental bool `json:"instrumental"` // 是否纯音乐
ExtendSecs int `json:"extend_secs"` // 延长秒杀
Instrumental bool `json:"instrumental"` // 是否纯音乐
ExtendSecs int `json:"extend_secs,omitempty"` // 延长秒杀
SongId string `json:"song_id,omitempty"` // 合并歌曲ID
AudioURL string `json:"audio_url"` // 用户上传音频地址
}
const (
VideoLuma = "luma"
VideoRunway = "runway"
VideoCog = "cog"
)
type VideoTask struct {
Id uint `json:"id"`
Channel string `json:"channel"`
UserId int `json:"user_id"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
Params VideoParams `json:"params"`
}
type VideoParams struct {
PromptOptimize bool `json:"prompt_optimize"` // 是否优化提示词
Loop bool `json:"loop"` // 是否循环参考图
StartImgURL string `json:"start_img_url"` // 第一帧参考图地址
EndImgURL string `json:"end_img_url"` // 最后一帧参考图地址
Model string `json:"model"` // 使用哪个模型生成视频
Radio string `json:"radio"` // 视频尺寸
Style string `json:"style"` // 风格
Duration int `json:"duration"` // 视频时长(秒)
}

View File

@@ -14,6 +14,7 @@ import (
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -28,33 +29,49 @@ import (
var logger = logger2.GetLogger()
// Manager 管理员
type Manager struct {
Username string `json:"username"`
Password string `json:"password"`
Captcha string `json:"captcha"` // 验证码
CaptchaId string `json:"captcha_id"` // 验证码id
}
const SuperManagerID = 1
type ManagerHandler struct {
handler.BaseHandler
redis *redis.Client
redis *redis.Client
captcha *service.CaptchaService
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client, captcha *service.CaptchaService) *ManagerHandler {
return &ManagerHandler{
BaseHandler: handler.BaseHandler{DB: db, App: app},
redis: client,
captcha: captcha,
}
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data Manager
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {

View File

@@ -49,7 +49,7 @@ func (h *UserHandler) List(c *gin.Context) {
}
session.Model(&model.User{}).Count(&total)
res := session.Offset(offset).Limit(pageSize).Find(&items)
res := session.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
if res.Error == nil {
for _, item := range items {
var user vo.User
@@ -204,33 +204,69 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
}
func (h *UserHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
if id <= 0 {
id := c.Query("id")
ids := c.QueryArray("ids[]")
if id != "" {
ids = append(ids, id)
}
if len(ids) == 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
// 删除用户
res := h.DB.Where("id = ?", id).Delete(&model.User{})
if res.Error != nil {
tx := h.DB.Begin()
var err error
for _, id = range ids {
// 删除用户
if err = tx.Where("id", id).Delete(&model.User{}).Error; err != nil {
break
}
// 删除聊天记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatItem{}).Error; err != nil {
break
}
// 删除聊天历史记录
if err = tx.Unscoped().Where("user_id = ?", id).Delete(&model.ChatMessage{}).Error; err != nil {
break
}
// 删除登录日志
if err = tx.Where("user_id = ?", id).Delete(&model.UserLoginLog{}).Error; err != nil {
break
}
// 删除算力日志
if err = tx.Where("user_id = ?", id).Delete(&model.PowerLog{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.InviteLog{}).Error; err != nil {
break
}
// 删除众筹日志
if err = tx.Where("user_id = ?", id).Delete(&model.Redeem{}).Error; err != nil {
break
}
// 删除绘图任务
if err = tx.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SdJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.DallJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.SunoJob{}).Error; err != nil {
break
}
if err = tx.Where("user_id = ?", id).Delete(&model.VideoJob{}).Error; err != nil {
break
}
}
if err != nil {
resp.ERROR(c, "删除失败")
tx.Rollback()
return
}
// 删除聊天记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
// 删除聊天历史记录
h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
// 删除登录日志
h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
// 删除算力日志
h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
// 删除众筹日志
h.DB.Where("user_id = ?", id).Delete(&model.Redeem{})
// 删除绘图任务
h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
// 删除订单
h.DB.Where("user_id = ?", id).Delete(&model.Order{})
tx.Commit()
resp.SUCCESS(c)
}

View File

@@ -46,9 +46,10 @@ type ChatHandler struct {
licenseService *service.LicenseService
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
userService *service.UserService
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
@@ -56,6 +57,7 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
licenseService: licenseService,
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
ChatContexts: types.NewLMap[string, []types.Message](),
userService: userService,
}
}
@@ -71,6 +73,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
tools := c.Query("tools")
client := types.NewWsClient(ws)
var chatRole model.ChatRole
@@ -97,6 +100,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
Tools: tools,
}
// use old chat data override the chat model and role ID
@@ -209,34 +213,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res = h.DB.Where("enabled", true).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
if session.Tools != "" {
toolIds := strings.Split(session.Tools, ",")
var items []model.Function
res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
if err != nil {
continue
}
tool := types.Tool{
Type: "function",
Function: types.Function{
Name: v.Name,
Description: v.Description,
Parameters: parameters,
},
}
if v, ok := parameters["required"]; v == nil || !ok {
tool.Function.Parameters["required"] = []string{}
}
tools = append(tools, tool)
}
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
}
}
@@ -270,7 +277,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
for i := len(messages) - 1; i >= 0; i-- {
v := messages[i]
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
@@ -481,24 +489,15 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userVo.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: u.Power,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
err := h.userService.DecreasePower(int(userVo.Id), power, model.PowerLog{
Type: types.PowerConsume,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
})
if err != nil {
logger.Error(err)
}
}
func (h *ChatHandler) saveChatHistory(
@@ -544,9 +543,9 @@ func (h *ChatHandler) saveChatHistory(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
err = h.DB.Save(&historyUserMsg).Error
if err != nil {
logger.Error("failed to save prompt history message: ", err)
}
// for reply
@@ -566,9 +565,9 @@ func (h *ChatHandler) saveChatHistory(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
err = h.DB.Create(&historyReplyMsg).Error
if err != nil {
logger.Error("failed to save reply history message: ", err)
}
// 更新用户算力
@@ -577,8 +576,8 @@ func (h *ChatHandler) saveChatHistory(
}
// 保存当前会话
var chatItem model.ChatItem
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
if err != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id
@@ -589,7 +588,10 @@ func (h *ChatHandler) saveChatHistory(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.DB.Create(&chatItem)
err = h.DB.Create(&chatItem).Error
if err != nil {
logger.Error("failed to save chat item: ", err)
}
}
}

View File

@@ -11,32 +11,33 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gorilla/websocket"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
redis *redis.Client
dallService *dalle.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager, userService *service.UserService) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
dallService: service,
uploader: manager,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -61,14 +62,14 @@ func (h *DallJobHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
h.dallService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
h.dallService.Clients.Delete(uint(userId))
return
}
@@ -127,7 +128,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
return
}
h.service.PushTask(types.DallTask{
h.dallService.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
@@ -137,7 +138,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
client := h.dallService.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -175,7 +176,7 @@ func (h *DallJobHandler) JobList(c *gin.Context) {
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
@@ -193,11 +194,14 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.DallJob{}).Count(&total)
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.DallJob, 0)
@@ -210,7 +214,7 @@ func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize in
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
@@ -233,26 +237,11 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
tx.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRefund,
Amount: job.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
CreatedAt: time.Now(),
}).Error
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())

View File

@@ -8,15 +8,16 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"errors"
"fmt"
"strings"
"time"
@@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content)
}
// List 获取所有的工具函数列表
func (h *FunctionHandler) List(c *gin.Context) {
var items []model.Function
err := h.DB.Where("enabled", true).Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
tools := make([]vo.Function, 0)
for _, v := range items {
var f vo.Function
err = utils.CopyObject(v, &f)
if err != nil {
continue
}
f.Action = ""
f.Token = ""
tools = append(tools, f)
}
resp.SUCCESS(c, tools)
}

View File

@@ -9,7 +9,6 @@ package handler
import (
"geekai/core"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
@@ -59,23 +58,16 @@ func (h *InviteHandler) Code(c *gin.Context) {
// List Log 用户邀请记录
func (h *InviteHandler) List(c *gin.Context) {
var data struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
userId := h.GetLoginUserId(c)
session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
var total int64
session.Model(&model.InviteLog{}).Count(&total)
var items []model.InviteLog
var list = make([]vo.InviteLog, 0)
offset := (data.Page - 1) * data.PageSize
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
offset := (page - 1) * pageSize
res := session.Order("id DESC").Offset(offset).Limit(pageSize).Find(&items)
if res.Error == nil {
for _, item := range items {
var v vo.InviteLog
@@ -89,7 +81,7 @@ func (h *InviteHandler) List(c *gin.Context) {
}
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, list))
}
// Hits 访问邀请码

View File

@@ -15,6 +15,7 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
@@ -30,13 +31,15 @@ import (
// MarkMapHandler 生成思维导图
type MarkMapHandler struct {
BaseHandler
clients *types.LMap[int, *types.WsClient]
clients *types.LMap[int, *types.WsClient]
userService *service.UserService
}
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *MarkMapHandler {
return &MarkMapHandler{
BaseHandler: BaseHandler{App: app, DB: db},
clients: types.NewLMap[int, *types.WsClient](),
userService: userService,
}
}
@@ -185,22 +188,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
// 扣减算力
if chatModel.Power > 0 {
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
if res.Error == nil {
// 记录算力消费日志
var u model.User
h.DB.Where("id", userId).First(&u)
h.DB.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerConsume,
Amount: chatModel.Power,
Mark: types.PowerSub,
Balance: u.Power,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
CreatedAt: time.Now(),
})
err = h.userService.DecreasePower(userId, chatModel.Power, model.PowerLog{
Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
})
if err != nil {
return err
}
}

View File

@@ -30,16 +30,18 @@ import (
type MidJourneyHandler struct {
BaseHandler
service *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
mjService *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
userService *service.UserService
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler {
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager, userService *service.UserService) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
service: service,
uploader: manager,
snowflake: snowflake,
mjService: service,
uploader: manager,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -80,7 +82,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
h.mjService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
@@ -196,7 +198,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
h.service.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
TaskId: taskId,
Type: types.TaskType(data.TaskType),
@@ -208,28 +210,22 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.service.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
@@ -269,7 +265,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return
}
h.service.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskUpscale,
UserId: userId,
@@ -280,27 +276,22 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.service.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("Upscale 操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
@@ -334,7 +325,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return
}
h.service.PushTask(types.MjTask{
h.mjService.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskVariation,
UserId: userId,
@@ -345,28 +336,21 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.service.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
@@ -401,7 +385,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
}
// JobList 获取 MJ 任务列表
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress >= ?", 100).Order("id DESC")
@@ -419,10 +403,14 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var items []model.MidJourneyJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.MidJourneyJob, 0)
@@ -442,7 +430,7 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
@@ -465,25 +453,11 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
tx.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRefund,
Amount: job.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%sErr: %s", job.TaskId, job.ErrMsg),
CreatedAt: time.Now(),
}).Error
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "mid-journey",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
@@ -498,7 +472,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err)
}
client := h.service.Clients.Get(uint(job.UserId))
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
@@ -21,11 +22,12 @@ import (
type RedeemHandler struct {
BaseHandler
lock sync.Mutex
lock sync.Mutex
userService *service.UserService
}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB) *RedeemHandler {
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}}
func NewRedeemHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService) *RedeemHandler {
return &RedeemHandler{BaseHandler: BaseHandler{App: app, DB: db}, userService: userService}
}
func (h *RedeemHandler) Verify(c *gin.Context) {
@@ -59,7 +61,11 @@ func (h *RedeemHandler) Verify(c *gin.Context) {
}
tx := h.DB.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", item.Power)).Error
err := h.userService.IncreasePower(int(userId), item.Power, model.PowerLog{
Type: types.PowerRedeem,
Model: "兑换码",
Remark: fmt.Sprintf("兑换码核销,算力:%d兑换码%s...", item.Power, item.Code[:10]),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
@@ -76,26 +82,6 @@ func (h *RedeemHandler) Verify(c *gin.Context) {
return
}
// 记录算力充值日志
var user model.User
err = tx.Where("id", userId).First(&user).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
h.DB.Create(&model.PowerLog{
UserId: userId,
Username: user.Username,
Type: types.PowerRedeem,
Amount: item.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: "兑换码",
Remark: fmt.Sprintf("兑换码核销,算力:%d兑换码%s...", item.Power, item.Code[:10]),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)

View File

@@ -31,19 +31,27 @@ import (
type SdJobHandler struct {
BaseHandler
redis *redis.Client
service *sd.Service
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
redis *redis.Client
sdService *sd.Service
uploader *oss.UploaderManager
snowflake *service.Snowflake
leveldb *store.LevelDB
userService *service.UserService
}
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, service *sd.Service, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
func NewSdJobHandler(app *core.AppServer,
db *gorm.DB,
service *sd.Service,
manager *oss.UploaderManager,
snowflake *service.Snowflake,
userService *service.UserService,
levelDB *store.LevelDB) *SdJobHandler {
return &SdJobHandler{
service: service,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
sdService: service,
uploader: manager,
snowflake: snowflake,
leveldb: levelDB,
userService: userService,
BaseHandler: BaseHandler{
App: app,
DB: db,
@@ -68,7 +76,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
h.sdService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
@@ -159,34 +167,27 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
h.service.PushTask(types.SdTask{
h.sdService.PushTask(types.SdTask{
Id: int(job.Id),
Type: types.TaskImage,
Params: params,
UserId: userId,
})
client := h.service.Clients.Get(uint(job.UserId))
client := h.sdService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
@@ -223,7 +224,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, vo.Page) {
session := h.DB.Session(&gorm.Session{})
if finish {
@@ -242,10 +243,14 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
session = session.Offset(offset).Limit(pageSize)
}
// 统计总数
var total int64
session.Model(&model.SdJob{}).Count(&total)
var items []model.SdJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
return res.Error, vo.Page{}
}
var jobs = make([]vo.SdJob, 0)
@@ -267,7 +272,7 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
jobs = append(jobs, job)
}
return nil, jobs
return nil, vo.NewPage(total, page, pageSize, jobs)
}
// Remove remove task image
@@ -290,25 +295,11 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
tx.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRefund,
Amount: job.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s Err: %s", job.TaskId, job.ErrMsg),
CreatedAt: time.Now(),
}).Error
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s Err: %s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())

View File

@@ -56,15 +56,17 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "验证码错误,请先完人机验证")
return
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
code := utils.RandomNumber(6)

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/suno"
"geekai/store/model"
@@ -26,18 +27,20 @@ import (
type SunoHandler struct {
BaseHandler
service *suno.Service
uploader *oss.UploaderManager
sunoService *suno.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler {
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager, userService *service.UserService) *SunoHandler {
return &SunoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
service: service,
uploader: uploader,
sunoService: service,
uploader: uploader,
userService: userService,
}
}
@@ -58,7 +61,7 @@ func (h *SunoHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
h.sunoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
@@ -72,15 +75,32 @@ func (h *SunoHandler) Create(c *gin.Context) {
Tags string `json:"tags"`
Title string `json:"title"`
Type int `json:"type"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
SongId string `json:"song_id,omitempty"` // 要拼接的歌曲id
AudioURL string `json:"audio_url,omitempty"` // 上传自己创作的歌曲
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 歌曲拼接
if data.SongId != "" && data.Type == 3 {
var song model.SunoJob
if err := h.DB.Where("song_id = ?", data.SongId).First(&song).Error; err == nil {
data.Instrumental = song.Instrumental
data.Model = song.ModelName
data.Tags = song.Tags
}
// 拼接歌词
var refSong model.SunoJob
if err := h.DB.Where("song_id = ?", data.RefSongId).First(&refSong).Error; err == nil {
data.Prompt = fmt.Sprintf("%s\n%s", song.Prompt, refSong.Prompt)
}
}
// 插入数据库
job := model.SunoJob{
UserId: int(h.GetLoginUserId(c)),
@@ -106,7 +126,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
}
// 创建任务
h.service.PushTask(types.SunoTask{
h.sunoService.PushTask(types.SunoTask{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
@@ -118,27 +138,22 @@ func (h *SunoHandler) Create(c *gin.Context) {
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
SongId: data.SongId,
AudioURL: data.AudioURL,
})
// update user's power
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
client := h.service.Clients.Get(uint(job.UserId))
client := h.sunoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -147,8 +162,8 @@ func (h *SunoHandler) Create(c *gin.Context) {
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数
@@ -220,25 +235,11 @@ func (h *SunoHandler) Remove(c *gin.Context) {
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
tx.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerRefund,
Amount: job.Power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
CreatedAt: time.Now(),
}).Error
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())

View File

@@ -17,19 +17,21 @@ import (
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"io"
"net/http"
"time"
)
type UploadHandler struct {
type NetHandler struct {
BaseHandler
uploaderManager *oss.UploaderManager
}
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
func NewNetHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *NetHandler {
return &NetHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
}
func (h *UploadHandler) Upload(c *gin.Context) {
func (h *NetHandler) Upload(c *gin.Context) {
file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
if err != nil {
resp.ERROR(c, err.Error())
@@ -60,7 +62,7 @@ func (h *UploadHandler) Upload(c *gin.Context) {
resp.SUCCESS(c, file)
}
func (h *UploadHandler) List(c *gin.Context) {
func (h *NetHandler) List(c *gin.Context) {
var data struct {
Urls []string `json:"urls,omitempty"`
}
@@ -95,7 +97,7 @@ func (h *UploadHandler) List(c *gin.Context) {
}
// Remove remove files
func (h *UploadHandler) Remove(c *gin.Context) {
func (h *NetHandler) Remove(c *gin.Context) {
userId := h.GetLoginUserId(c)
id := h.GetInt(c, "id", 0)
var file model.File
@@ -119,3 +121,28 @@ func (h *UploadHandler) Remove(c *gin.Context) {
_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
resp.SUCCESS(c)
}
func (h *NetHandler) Download(c *gin.Context) {
fileUrl := c.Query("url")
// 使用http工具下载文件
if fileUrl == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
// 使用http.Get下载文件
r, err := http.Get(fileUrl)
if err != nil {
resp.ERROR(c, err.Error())
return
}
defer r.Body.Close()
if r.StatusCode != http.StatusOK {
resp.ERROR(c, "error status"+r.Status)
return
}
c.Status(http.StatusOK)
// 将下载的文件内容写入响应
_, _ = io.Copy(c.Writer, r.Body)
}

View File

@@ -33,6 +33,8 @@ type UserHandler struct {
searcher *xdb.Searcher
redis *redis.Client
licenseService *service.LicenseService
captcha *service.CaptchaService
userService *service.UserService
}
func NewUserHandler(
@@ -40,12 +42,16 @@ func NewUserHandler(
db *gorm.DB,
searcher *xdb.Searcher,
client *redis.Client,
captcha *service.CaptchaService,
userService *service.UserService,
licenseService *service.LicenseService) *UserHandler {
return &UserHandler{
BaseHandler: BaseHandler{DB: db, App: app},
searcher: searcher,
redis: client,
captcha: captcha,
licenseService: licenseService,
userService: userService,
}
}
@@ -55,9 +61,14 @@ func (h *UserHandler) Register(c *gin.Context) {
var data struct {
RegWay string `json:"reg_way"`
Username string `json:"username"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Password string `json:"password"`
Code string `json:"code"`
InviteCode string `json:"invite_code"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -79,8 +90,15 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码
var key string
if data.RegWay == "email" || data.RegWay == "mobile" {
key = CodeStorePrefix + data.Username
if data.RegWay == "email" {
key = CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
} else if data.RegWay == "mobile" {
key = CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
@@ -100,7 +118,17 @@ func (h *UserHandler) Register(c *gin.Context) {
// check if the username is existing
var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item)
session := h.DB.Session(&gorm.Session{})
if data.Mobile != "" {
session = session.Where("mobile = ?", data.Mobile)
data.Username = data.Mobile
} else if data.Email != "" {
session = session.Where("email = ?", data.Email)
data.Username = data.Email
} else if data.Username != "" {
session = session.Where("username = ?", data.Username)
}
session.First(&item)
if item.Id > 0 {
resp.ERROR(c, "该用户名已经被注册")
return
@@ -109,6 +137,8 @@ func (h *UserHandler) Register(c *gin.Context) {
salt := utils.RandString(8)
user := model.User{
Username: data.Username,
Mobile: data.Mobile,
Email: data.Email,
Password: utils.GenPassword(data.Password, salt),
Avatar: "/images/avatar/user.png",
Salt: salt,
@@ -128,10 +158,9 @@ func (h *UserHandler) Register(c *gin.Context) {
user.Nickname = fmt.Sprintf("极客学长@%d", utils.RandomNumber(6))
}
res = h.DB.Create(&user)
if res.Error != nil {
resp.ERROR(c, "保存数据失败")
logger.Error(res.Error)
tx := h.DB.Begin()
if err := tx.Create(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
@@ -140,35 +169,35 @@ func (h *UserHandler) Register(c *gin.Context) {
// 增加邀请数量
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InvitePower > 0 {
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
// 记录邀请算力充值日志
var inviter model.User
h.DB.Where("id", inviteCode.UserId).First(&inviter)
h.DB.Create(&model.PowerLog{
UserId: inviter.Id,
Username: inviter.Username,
Type: types.PowerInvite,
Amount: h.App.SysConfig.InvitePower,
Balance: inviter.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
CreatedAt: time.Now(),
err := h.userService.IncreasePower(int(inviteCode.UserId), h.App.SysConfig.InvitePower, model.PowerLog{
Type: types.PowerInvite,
Model: "",
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d邀请码%s新用户%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
// 添加邀请记录
h.DB.Create(&model.InviteLog{
err := tx.Create(&model.InviteLog{
InviterId: inviteCode.UserId,
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
// 自动登录创建 token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.Id,
@@ -193,11 +222,28 @@ func (h *UserHandler) Login(c *gin.Context) {
var data struct {
Username string `json:"username"`
Password string `json:"password"`
Key string `json:"key,omitempty"`
Dots string `json:"dots,omitempty"`
X int `json:"x,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if h.App.SysConfig.EnabledVerify {
var check bool
if data.X != 0 {
check = h.captcha.SlideCheck(data)
} else {
check = h.captcha.Check(data)
}
if !check {
resp.ERROR(c, "请先完人机验证")
return
}
}
var user model.User
res := h.DB.Where("username = ?", data.Username).First(&user)
if res.Error != nil {
@@ -285,8 +331,10 @@ func (h *UserHandler) CLogin(c *gin.Context) {
// CLoginCallback 第三方登录回调
func (h *UserHandler) CLoginCallback(c *gin.Context) {
loginType := h.GetTrim(c, "login_type")
code := h.GetTrim(c, "code")
loginType := c.Query("login_type")
code := c.Query("code")
userId := h.GetInt(c, "user_id", 0)
action := c.Query("action")
var res types.BizVo
apiURL := fmt.Sprintf("%s/api/clogin/info", h.App.Config.ApiConfig.ApiURL)
@@ -311,11 +359,34 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
// login successfully
data := res.Data.(map[string]interface{})
session := gin.H{}
var user model.User
tx := h.DB.Debug().Where("openid", data["openid"]).First(&user)
if tx.Error != nil { // user not exist, create new user
// 检测最大注册人数
if action == "bind" && userId > 0 {
err = h.DB.Where("openid", data["openid"]).First(&user).Error
if err == nil {
resp.ERROR(c, "该微信已经绑定其他账号,请先解绑")
return
}
err = h.DB.Where("id", userId).First(&user).Error
if err != nil {
resp.ERROR(c, "绑定用户不存在")
return
}
err = h.DB.Model(&user).UpdateColumn("openid", data["openid"]).Error
if err != nil {
resp.ERROR(c, "更新用户信息失败,"+err.Error())
return
}
resp.SUCCESS(c, gin.H{"token": ""})
return
}
session := gin.H{}
tx := h.DB.Where("openid", data["openid"]).First(&user)
if tx.Error != nil {
// create new user
var totalUser int64
h.DB.Model(&model.User{}).Count(&totalUser)
if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum {
@@ -383,18 +454,24 @@ func (h *UserHandler) CLoginCallback(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
user, err := h.GetLoginUser(c)
if err == nil {
var userVo vo.User
err := utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
} else {
resp.NotAuth(c)
if err != nil {
resp.NotAuth(c, err.Error())
return
}
var userVo vo.User
err = utils.CopyObject(user, &userVo)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 用户 VIP 到期
if user.ExpiredTime > 0 && user.ExpiredTime < time.Now().Unix() {
h.DB.Model(&user).UpdateColumn("vip", false)
}
userVo.Id = user.Id
resp.SUCCESS(c, userVo)
}
type userProfile struct {
@@ -490,10 +567,12 @@ func (h *UserHandler) UpdatePass(c *gin.Context) {
resp.SUCCESS(c)
}
// ResetPass 重置密码
// ResetPass 找回密码
func (h *UserHandler) ResetPass(c *gin.Context) {
var data struct {
Username string `json:"username"`
Type string `json:"type"` // 验证类别mobile, email
Mobile string `json:"mobile"` // 手机号
Email string `json:"email"` // 邮箱地址
Code string `json:"code"` // 验证码
Password string `json:"password"` // 新密码
}
@@ -502,37 +581,47 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
return
}
session := h.DB.Session(&gorm.Session{})
var key string
if data.Type == "email" {
session = session.Where("email", data.Email)
key = CodeStorePrefix + data.Email
} else if data.Type == "mobile" {
session = session.Where("mobile", data.Email)
key = CodeStorePrefix + data.Mobile
} else {
resp.ERROR(c, "验证类别错误")
return
}
var user model.User
res := h.DB.Where("username", data.Username).First(&user)
if res.Error != nil {
err := session.First(&user).Error
if err != nil {
resp.ERROR(c, "用户不存在!")
return
}
// 检查验证码
key := CodeStorePrefix + data.Username
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
resp.ERROR(c, "验证码错误")
return
}
password := utils.GenPassword(data.Password, user.Salt)
user.Password = password
res = h.DB.Updates(&user)
if res.Error != nil {
resp.ERROR(c)
err = h.DB.Model(&user).UpdateColumn("password", password).Error
if err != nil {
resp.ERROR(c, err.Error())
} else {
h.redis.Del(c, key)
resp.SUCCESS(c)
}
}
// BindUsername 重置账
func (h *UserHandler) BindUsername(c *gin.Context) {
// BindMobile 绑定手机
func (h *UserHandler) BindMobile(c *gin.Context) {
var data struct {
Username string `json:"username"`
Code string `json:"code"`
Mobile string `json:"mobile"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -540,7 +629,7 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
}
// 检查验证码
key := CodeStorePrefix + data.Username
key := CodeStorePrefix + data.Mobile
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
@@ -549,19 +638,54 @@ func (h *UserHandler) BindUsername(c *gin.Context) {
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("username = ?", data.Username).First(&item)
res := h.DB.Where("mobile", data.Mobile).First(&item)
if res.Error == nil {
resp.ERROR(c, "该号已经其他账号绑定")
resp.ERROR(c, "该手机号已经绑定了其他账号,请更换手机号")
return
}
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
userId := h.GetLoginUserId(c)
err = h.DB.Model(&user).UpdateColumn("username", data.Username).Error
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("mobile", data.Mobile).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
_ = h.redis.Del(c, key) // 删除短信验证码
resp.SUCCESS(c)
}
// BindEmail 绑定邮箱
func (h *UserHandler) BindEmail(c *gin.Context) {
var data struct {
Email string `json:"email"`
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 检查验证码
key := CodeStorePrefix + data.Email
code, err := h.redis.Get(c, key).Result()
if err != nil || code != data.Code {
resp.ERROR(c, "验证码错误")
return
}
// 检查手机号是否被其他账号绑定
var item model.User
res := h.DB.Where("email", data.Email).First(&item)
if res.Error == nil {
resp.ERROR(c, "该邮箱地址已经绑定了其他账号,请更邮箱地址")
return
}
userId := h.GetLoginUserId(c)
err = h.DB.Model(&item).Where("id", userId).UpdateColumn("email", data.Email).Error
if err != nil {
resp.ERROR(c, err.Error())
return

View File

@@ -0,0 +1,233 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/service/video"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
)
type VideoHandler struct {
BaseHandler
videoService *video.Service
uploader *oss.UploaderManager
userService *service.UserService
}
func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, uploader *oss.UploaderManager, userService *service.UserService) *VideoHandler {
return &VideoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
videoService: service,
uploader: uploader,
userService: userService,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *VideoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.videoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
FirstFrameImg string `json:"first_frame_img,omitempty"`
EndFrameImg string `json:"end_frame_img,omitempty"`
ExpandPrompt bool `json:"expand_prompt,omitempty"`
Loop bool `json:"loop,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Prompt == "" {
resp.ERROR(c, "prompt is needed")
return
}
userId := int(h.GetLoginUserId(c))
params := types.VideoParams{
PromptOptimize: data.ExpandPrompt,
Loop: data.Loop,
StartImgURL: data.FirstFrameImg,
EndImgURL: data.EndFrameImg,
}
// 插入数据库
job := model.VideoJob{
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Power: h.App.SysConfig.LumaPower,
Params: utils.JsonEncode(params),
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.videoService.PushTask(types.VideoTask{
Id: job.Id,
UserId: userId,
Type: types.VideoLuma,
Prompt: data.Prompt,
Params: params,
})
// update user's power
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "luma",
Remark: fmt.Sprintf("Luma 文生视频任务ID%d", job.Id),
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
client := h.videoService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
func (h *VideoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
t := c.Query("type")
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
all := h.GetBool(c, "all")
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
if t != "" {
session = session.Where("type", t)
}
if all {
session = session.Where("publish", 0).Where("progress", 100)
} else {
session = session.Where("user_id", h.GetLoginUserId(c))
}
// 统计总数
var total int64
session.Model(&model.VideoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.VideoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 转换为 VO
items := make([]vo.VideoJob, 0)
for _, v := range list {
var item vo.VideoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *VideoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "luma",
Remark: fmt.Sprintf("Luma 任务失败退回算力。任务ID%sErr:%s", job.TaskId, job.ErrMsg),
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.VideoURL)
}
func (h *VideoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
var job model.VideoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err = h.DB.Model(&job).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -24,6 +24,7 @@ import (
"geekai/service/sd"
"geekai/service/sms"
"geekai/service/suno"
"geekai/service/video"
"geekai/store"
"io"
"log"
@@ -128,7 +129,7 @@ func main() {
fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler),
fx.Provide(chatimpl.NewChatHandler),
fx.Provide(handler.NewUploadHandler),
fx.Provide(handler.NewNetHandler),
fx.Provide(handler.NewSmsHandler),
fx.Provide(handler.NewRedeemHandler),
fx.Provide(handler.NewCaptchaHandler),
@@ -199,9 +200,16 @@ func main() {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadImages()
s.DownloadFiles()
}),
fx.Provide(video.NewService),
fx.Invoke(func(s *video.Service) {
s.Run()
s.SyncTaskProgress()
s.CheckTaskNotify()
s.DownloadFiles()
}),
fx.Provide(service.NewUserService),
fx.Provide(payment.NewAlipayService),
fx.Provide(payment.NewHuPiPay),
fx.Provide(payment.NewJPayService),
@@ -231,7 +239,8 @@ func main() {
group.GET("profile", h.Profile)
group.POST("profile/update", h.ProfileUpdate)
group.POST("password", h.UpdatePass)
group.POST("bind/username", h.BindUsername)
group.POST("bind/mobile", h.BindMobile)
group.POST("bind/email", h.BindEmail)
group.POST("resetPass", h.ResetPass)
group.GET("clogin", h.CLogin)
group.GET("clogin/callback", h.CLoginCallback)
@@ -248,10 +257,11 @@ func main() {
group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
}),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
fx.Invoke(func(s *core.AppServer, h *handler.NetHandler) {
s.Engine.POST("/api/upload", h.Upload)
s.Engine.POST("/api/upload/list", h.List)
s.Engine.GET("/api/upload/remove", h.Remove)
s.Engine.GET("/api/download", h.Download)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
group := s.Engine.Group("/api/sms/")
@@ -398,7 +408,7 @@ func main() {
fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
group := s.Engine.Group("/api/invite/")
group.GET("code", h.Code)
group.POST("list", h.List)
group.GET("list", h.List)
group.GET("hits", h.Hits)
}),
@@ -423,6 +433,7 @@ func main() {
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
@@ -482,6 +493,15 @@ func main() {
group.GET("play", h.Play)
group.POST("lyric", h.Lyric)
}),
fx.Provide(handler.NewVideoHandler),
fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
group := s.Engine.Group("/api/video")
group.Any("client", h.Client)
group.POST("luma/create", h.LumaCreate)
group.GET("list", h.List)
group.GET("remove", h.Remove)
group.GET("publish", h.Publish)
}),
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test")

View File

@@ -35,9 +35,10 @@ type Service struct {
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
userService *service.UserService
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client, userService *service.UserService) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
@@ -45,6 +46,7 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
userService: userService,
}
}
@@ -122,32 +124,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", errors.New("insufficient of power")
}
// 更新用户算力
tx := s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
// 扣减算力
err := s.userService.DecreasePower(int(user.Id), task.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
})
if err != nil {
return "", fmt.Errorf("error with decrease power: %v", err)
}
// get image generation API KEY
var apiKey model.ApiKey
tx = s.db.Where("type", "dalle").
err = s.db.Where("type", "dalle").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available DALL-E api key: %v", tx.Error)
Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return "", fmt.Errorf("no available DALL-E api key: %v", err)
}
var res imgRes
@@ -181,13 +174,13 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
err = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
if tx.Error != nil {
return "", fmt.Errorf("err with update database: %v", tx.Error)
}).Error
if err != nil {
return "", fmt.Errorf("err with update database: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})

View File

@@ -82,14 +82,21 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
r, err := s.Create(task)
var r RespVo
if task.Type == 3 && task.SongId != "" { // 歌曲拼接
r, err = s.Merge(task)
} else if task.Type == 4 && task.AudioURL != "" { // 上传歌曲
r, err = s.Upload(task)
} else { // 歌曲创作
r, err = s.Create(task)
}
if err != nil {
logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
@@ -138,7 +145,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
}
var res RespVo
apiURL := fmt.Sprintf("%s/task/suno/v1/submit/music", apiKey.ApiURL)
apiURL := fmt.Sprintf("%s/suno/submit/music", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
@@ -164,6 +171,97 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
return res, nil
}
func (s *Service) Merge(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"clip_id": task.SongId,
"is_infill": false,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/submit/concat", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "suno").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
"url": task.AudioURL,
}
var res RespVo
apiURL := fmt.Sprintf("%s/suno/uploads/audio-url", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return RespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 {
return RespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
@@ -185,7 +283,7 @@ func (s *Service) CheckTaskNotify() {
}()
}
func (s *Service) DownloadImages() {
func (s *Service) DownloadFiles() {
go func() {
var items []model.SunoJob
for {
@@ -331,15 +429,15 @@ type QueryRespVo struct {
func (s *Service) QueryTask(taskId string, channel string) (QueryRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
tx := s.db.Session(&gorm.Session{}).Where("type", "suno").
err := s.db.Session(&gorm.Session{}).Where("type", "suno").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return QueryRespVo{}, errors.New("no available API KEY for Suno")
}
apiURL := fmt.Sprintf("%s/task/suno/v1/fetch/%s", apiKey.ApiURL, taskId)
apiURL := fmt.Sprintf("%s/suno/fetch/%s", apiKey.ApiURL, taskId)
var res QueryRespVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)

View File

@@ -0,0 +1,83 @@
package service
import (
"fmt"
"geekai/core/types"
"geekai/store/model"
"gorm.io/gorm"
"sync"
"time"
)
type UserService struct {
db *gorm.DB
lock sync.Mutex
}
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db, lock: sync.Mutex{}}
}
// IncreasePower 增加用户算力
func (s *UserService) IncreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power + ?", power)).Error
if err != nil {
tx.Rollback()
return err
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerAdd,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
// DecreasePower 减少用户算力
func (s *UserService) DecreasePower(userId int, power int, log model.PowerLog) error {
s.lock.Lock()
defer s.lock.Unlock()
tx := s.db.Begin()
err := tx.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", power)).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("扣减算力失败:%v", err)
}
var user model.User
tx.Where("id", userId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: log.Type,
Amount: power,
Balance: user.Power,
Mark: types.PowerSub,
Model: log.Model,
Remark: log.Remark,
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
return fmt.Errorf("记录算力日志失败:%v", err)
}
tx.Commit()
return nil
}

330
api/service/video/luma.go Normal file
View File

@@ -0,0 +1,330 @@
package video
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"io"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("Video_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("Video_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
func (s *Service) PushTask(task types.VideoTask) {
logger.Infof("add a new Video task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
// 将数据库中未提交的人物加载到队列
var jobs []model.VideoJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var params types.VideoParams
if err := utils.JsonDecode(v.Params, &params); err != nil {
logger.Errorf("unmarshal params failed: %v", err)
continue
}
s.PushTask(types.VideoTask{
Id: v.Id,
Channel: v.Channel,
UserId: v.UserId,
Type: v.Type,
TaskId: v.TaskId,
Prompt: v.Prompt,
Params: params,
})
}
logger.Info("Starting Video job consumer...")
go func() {
for {
var task types.VideoTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
var r LumaRespVo
r, err = s.LumaCreate(task)
if err != nil {
logger.Errorf("create task with error: %v", err)
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(),
"progress": service.FailTaskProgress,
"cover_url": "/images/failed.jpg",
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(task.Id), Message: service.TaskStatusFailed})
continue
}
// 更新任务信息
err = s.db.Model(&model.VideoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"task_id": r.Id,
"channel": r.Channel,
"prompt_ext": r.Prompt,
}).Error
if err != nil {
logger.Errorf("update task with error: %v", err)
s.PushTask(task)
}
}
}()
}
type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
CreatedAt time.Time `json:"created_at"`
Video interface{} `json:"video"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Channel string `json:"channel,omitempty"`
}
func (s *Service) LumaCreate(task types.VideoTask) (LumaRespVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
session := s.db.Session(&gorm.Session{}).Where("type", "luma").Where("enabled", true)
if task.Channel != "" {
session = session.Where("api_url", task.Channel)
}
tx := session.Order("last_used_at DESC").First(&apiKey)
if tx.Error != nil {
return LumaRespVo{}, errors.New("no available API KEY for Luma")
}
reqBody := map[string]interface{}{
"user_prompt": task.Prompt,
"expand_prompt": task.Params.PromptOptimize,
"loop": task.Params.Loop,
"image_url": task.Params.StartImgURL,
"image_end_url": task.Params.EndImgURL,
}
var res LumaRespVo
apiURL := fmt.Sprintf("%s/luma/generations", apiKey.ApiURL)
logger.Debugf("API URL: %s, request body: %+v", apiURL, reqBody)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(reqBody).
Post(apiURL)
if err != nil {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.StatusCode != 200 && r.StatusCode != 201 {
return LumaRespVo{}, fmt.Errorf("请求 API 出错:%d, %s", r.StatusCode, r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaRespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL
return res, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Suno task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadFiles() {
go func() {
var items []model.VideoJob
for {
res := s.db.Where("progress", 102).Find(&items)
if res.Error != nil {
continue
}
for _, v := range items {
if v.WaterURL == "" {
continue
}
logger.Infof("try download video: %s", v.WaterURL)
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
logger.Infof("download video success: %s", videoURL)
v.WaterURL = videoURL
if v.VideoURL != "" {
logger.Infof("try download no water video: %s", v.VideoURL)
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
if err != nil {
logger.Errorf("download video with error: %v", err)
continue
}
}
logger.Info("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
}
time.Sleep(time.Second * 10)
}
}()
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.VideoJob
for {
res := s.db.Where("progress < ?", 100).Where("task_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
task, err := s.QueryLumaTask(job.TaskId, job.Channel)
if err != nil {
logger.Errorf("query task with error: %v", err)
// 更新任务信息
s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress, // 102 表示资源未下载完成,
"err_msg": err.Error(),
})
continue
}
logger.Debugf("task: %+v", task)
if task.State == "completed" { // 更新任务信息
data := map[string]interface{}{
"progress": 102, // 102 表示资源未下载完成,
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
}
err = s.db.Model(&model.VideoJob{Id: job.Id}).UpdateColumns(data).Error
if err != nil {
logger.Errorf("更新数据库失败:%v", err)
continue
}
}
}
time.Sleep(time.Second * 10)
}
}()
}
type LumaTaskVo struct {
Id string `json:"id"`
Liked interface{} `json:"liked"`
State string `json:"state"`
Video struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
// 读取 API KEY
var apiKey model.ApiKey
err := s.db.Session(&gorm.Session{}).Where("type", "luma").
Where("api_url", channel).
Where("enabled", true).
Order("last_used_at DESC").First(&apiKey).Error
if err != nil {
return LumaTaskVo{}, errors.New("no available API KEY for Luma")
}
apiURL := fmt.Sprintf("%s/luma/generations/%s", apiKey.ApiURL, taskId)
var res LumaTaskVo
r, err := req.C().R().SetHeader("Authorization", "Bearer "+apiKey.Value).Get(apiURL)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("请求 API 失败:%v", err)
}
defer r.Body.Close()
if r.StatusCode != 200 {
return LumaTaskVo{}, fmt.Errorf("API 返回失败:%v", r.String())
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &res)
if err != nil {
return LumaTaskVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
return res, nil
}

View File

@@ -81,54 +81,6 @@ func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (ms
// 自动将 VIP 会员的算力补充到每月赠送的最大值
func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("开始进行月底账号盘点...")
var users []model.User
res := e.db.Where("vip", 1).Where("status", 1).Find(&users)
if res.Error != nil {
return "No vip users found"
}
var sysConfig model.Config
res = e.db.Where("marker", "system").First(&sysConfig)
if res.Error != nil {
return "error with get system config: " + res.Error.Error()
}
var config types.SystemConfig
err := utils.JsonDecode(sysConfig.Config, &config)
if err != nil {
return "error with decode system config: " + err.Error()
}
for _, u := range users {
// 处理过期的 VIP
if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() {
u.Vip = false
e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false)
continue
}
if u.Power < config.VipMonthPower {
power := config.VipMonthPower - u.Power
// update user
tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power))
// 记录算力变动日志
if tx.Error == nil {
var user model.User
e.db.Where("id", u.Id).First(&user)
e.db.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerRecharge,
Amount: power,
Mark: types.PowerAdd,
Balance: user.Power,
Model: "系统盘点",
Remark: fmt.Sprintf("VIP会员每月算力派发%d", config.VipMonthPower),
CreatedAt: time.Now(),
})
}
}
}
logger.Info("月底盘点完成!")
return "success"
}

View File

@@ -29,15 +29,9 @@ func NewLevelDB() (*LevelDB, error) {
}
func (db *LevelDB) Put(key string, value interface{}) error {
var byteData []byte
if v, ok := value.(string); ok {
byteData = []byte(v)
} else {
b, err := json.Marshal(value)
if err != nil {
return err
}
byteData = b
byteData, err := json.Marshal(value)
if err != nil {
return err
}
return db.driver.Put([]byte(key), byteData, nil)
}

View File

@@ -4,6 +4,8 @@ type User struct {
BaseModel
Username string
Nickname string
Email string
Mobile string
Password string
Avatar string
Salt string // 密码盐

View File

@@ -0,0 +1,27 @@
package model
import "time"
type VideoJob struct {
Id uint `gorm:"primarykey;column:id"`
UserId int
Channel string // 频道
Type string // luma,runway,cog
TaskId string
Prompt string // 提示词
PromptExt string // 优化后提示词
CoverURL string // 封面图 URL
VideoURL string // 无水印视频 URL
WaterURL string // 有水印视频 URL
Progress int // 任务进度
Publish bool // 是否发布
ErrMsg string // 错误信息
RawData string // 原始数据 json
Power int // 消耗算力
Params string // 任务参数
CreatedAt time.Time
}
func (VideoJob) TableName() string {
return "chatgpt_video_jobs"
}

View File

@@ -5,7 +5,7 @@ type SunoJob struct {
UserId int `json:"user_id"`
Channel string `json:"channel"`
Title string `json:"title"`
Type string `json:"type"`
Type int `json:"type"`
TaskId string `json:"task_id"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id
Tags string `json:"tags"` // 歌曲风格和标签
@@ -28,7 +28,3 @@ type SunoJob struct {
PlayTimes int `json:"play_times"` // 播放次数
CreatedAt int64 `json:"created_at"`
}
func (SunoJob) TableName() string {
return "chatgpt_suno_jobs"
}

View File

@@ -4,6 +4,8 @@ type User struct {
BaseVo
Username string `json:"username"`
Nickname string `json:"nickname"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐
Power int `json:"power"` // 剩余算力

23
api/store/vo/video_job.go Normal file
View File

@@ -0,0 +1,23 @@
package vo
import "geekai/core/types"
type VideoJob struct {
Id uint `json:"id"`
UserId int `json:"user_id"`
Channel string `json:"channel"`
Type string `json:"type"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
PromptExt string `json:"prompt_ext"` // 提示词
CoverURL string `json:"cover_url"` // 封面图 URL
VideoURL string `json:"video_url"` // 无水印视频 URL
WaterURL string `json:"water_url"` // 有水印视频 URL
Progress int `json:"progress"` // 任务进度
Publish bool `json:"publish"` // 是否发布
ErrMsg string `json:"err_msg"` // 错误信息
RawData map[string]interface{} `json:"raw_data"` // 原始数据 json
Power int `json:"power"` // 消耗算力
Params types.VideoParams `json:"params"` // 任务参数
CreatedAt int64 `json:"created_at"`
}