mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
merge v4.1.6
This commit is contained in:
@@ -143,65 +143,67 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
|
||||
// FixData 修复数据
|
||||
func (h *ConfigHandler) FixData(c *gin.Context) {
|
||||
var fixed bool
|
||||
version := "data_fix_4.1.4"
|
||||
err := h.levelDB.Get(version, &fixed)
|
||||
if err == nil || fixed {
|
||||
resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||
return
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
var users []model.User
|
||||
err = tx.Find(&users).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.Email != "" || user.Mobile != "" {
|
||||
continue
|
||||
}
|
||||
if utils.IsValidEmail(user.Username) {
|
||||
user.Email = user.Username
|
||||
} else if utils.IsValidMobile(user.Username) {
|
||||
user.Mobile = user.Username
|
||||
}
|
||||
err = tx.Save(&user).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var orders []model.Order
|
||||
err = h.DB.Find(&orders).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
for _, order := range orders {
|
||||
if order.PayWay == "支付宝" {
|
||||
order.PayWay = "alipay"
|
||||
order.PayType = "alipay"
|
||||
} else if order.PayWay == "微信支付" {
|
||||
order.PayWay = "wechat"
|
||||
order.PayType = "wxpay"
|
||||
} else if order.PayWay == "hupi" {
|
||||
order.PayType = "wxpay"
|
||||
}
|
||||
err = tx.Save(&order).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
err = h.levelDB.Put(version, true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
resp.ERROR(c, "当前升级版本没有数据需要修正!")
|
||||
return
|
||||
//var fixed bool
|
||||
//version := "data_fix_4.1.4"
|
||||
//err := h.levelDB.Get(version, &fixed)
|
||||
//if err == nil || fixed {
|
||||
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
|
||||
// return
|
||||
//}
|
||||
//tx := h.DB.Begin()
|
||||
//var users []model.User
|
||||
//err = tx.Find(&users).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, user := range users {
|
||||
// if user.Email != "" || user.Mobile != "" {
|
||||
// continue
|
||||
// }
|
||||
// if utils.IsValidEmail(user.Username) {
|
||||
// user.Email = user.Username
|
||||
// } else if utils.IsValidMobile(user.Username) {
|
||||
// user.Mobile = user.Username
|
||||
// }
|
||||
// err = tx.Save(&user).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//var orders []model.Order
|
||||
//err = h.DB.Find(&orders).Error
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//for _, order := range orders {
|
||||
// if order.PayWay == "支付宝" {
|
||||
// order.PayWay = "alipay"
|
||||
// order.PayType = "alipay"
|
||||
// } else if order.PayWay == "微信支付" {
|
||||
// order.PayWay = "wechat"
|
||||
// order.PayType = "wxpay"
|
||||
// } else if order.PayWay == "hupi" {
|
||||
// order.PayType = "wxpay"
|
||||
// }
|
||||
// err = tx.Save(&order).Error
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// tx.Rollback()
|
||||
// return
|
||||
// }
|
||||
//}
|
||||
//tx.Commit()
|
||||
//err = h.levelDB.Put(version, true)
|
||||
//if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
//}
|
||||
//resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
254
api/handler/admin/image_handler.go
Normal file
254
api/handler/admin/image_handler.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ImageHandler struct {
|
||||
handler.BaseHandler
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
|
||||
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
type imageQuery struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt []string `json:"created_at"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// MjList Midjourney 任务列表
|
||||
func (h *ImageHandler) MjList(c *gin.Context) {
|
||||
var data imageQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
if err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
if data.Prompt != "" {
|
||||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||
}
|
||||
if len(data.CreatedAt) == 2 {
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||
}
|
||||
var total int64
|
||||
session.Model(&model.MidJourneyJob{}).Count(&total)
|
||||
var list []model.MidJourneyJob
|
||||
var items = make([]vo.MidJourneyJob, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||
if err == nil {
|
||||
// 填充数据
|
||||
for _, item := range list {
|
||||
var job vo.MidJourneyJob
|
||||
err = utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
job.CreatedAt = item.CreatedAt.Unix()
|
||||
items = append(items, job)
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
// SdList Stable Diffusion 任务列表
|
||||
func (h *ImageHandler) SdList(c *gin.Context) {
|
||||
var data imageQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
if err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
if data.Prompt != "" {
|
||||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||
}
|
||||
if len(data.CreatedAt) == 2 {
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||
}
|
||||
var total int64
|
||||
session.Model(&model.SdJob{}).Count(&total)
|
||||
var list []model.SdJob
|
||||
var items = make([]vo.SdJob, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||
if err == nil {
|
||||
// 填充数据
|
||||
for _, item := range list {
|
||||
var job vo.SdJob
|
||||
err = utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
job.CreatedAt = item.CreatedAt.Unix()
|
||||
items = append(items, job)
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
// DallList DALL-E 任务列表
|
||||
func (h *ImageHandler) DallList(c *gin.Context) {
|
||||
var data imageQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
if err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
if data.Prompt != "" {
|
||||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||
}
|
||||
if len(data.CreatedAt) == 2 {
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||
}
|
||||
var total int64
|
||||
session.Model(&model.DallJob{}).Count(&total)
|
||||
var list []model.DallJob
|
||||
var items = make([]vo.DallJob, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||
if err == nil {
|
||||
// 填充数据
|
||||
for _, item := range list {
|
||||
var job vo.DallJob
|
||||
err = utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
job.CreatedAt = item.CreatedAt.Unix()
|
||||
items = append(items, job)
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
tab := c.Query("tab")
|
||||
|
||||
tx := h.DB.Begin()
|
||||
var md, remark, imgURL string
|
||||
var power, userId, progress int
|
||||
switch tab {
|
||||
case "mj":
|
||||
var job model.MidJourneyJob
|
||||
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
tx.Delete(&job)
|
||||
md = "mid-journey"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "sd":
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除任务
|
||||
tx.Delete(&job)
|
||||
md = "stable-diffusion"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "dall":
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除任务
|
||||
tx.Delete(&job)
|
||||
md = "dall-e-3"
|
||||
power = job.Power
|
||||
userId = int(job.UserId)
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
// remove image
|
||||
err := h.uploader.GetUploadHandler().Delete(imgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
200
api/handler/admin/media_handler.go
Normal file
200
api/handler/admin/media_handler.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package admin
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MediaHandler struct {
|
||||
handler.BaseHandler
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
|
||||
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
|
||||
}
|
||||
|
||||
type mediaQuery struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Username string `json:"username"`
|
||||
CreatedAt []string `json:"created_at"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// SunoList Suno 任务列表
|
||||
func (h *MediaHandler) SunoList(c *gin.Context) {
|
||||
var data mediaQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
if err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
if data.Prompt != "" {
|
||||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||
}
|
||||
if len(data.CreatedAt) == 2 {
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||
}
|
||||
var total int64
|
||||
session.Model(&model.SunoJob{}).Count(&total)
|
||||
var list []model.SunoJob
|
||||
var items = make([]vo.SunoJob, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||
if err == nil {
|
||||
// 填充数据
|
||||
for _, item := range list {
|
||||
var job vo.SunoJob
|
||||
err = utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
job.CreatedAt = item.CreatedAt.Unix()
|
||||
items = append(items, job)
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
// LumaList Luma 视频任务列表
|
||||
func (h *MediaHandler) LumaList(c *gin.Context) {
|
||||
var data mediaQuery
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
session := h.DB.Session(&gorm.Session{})
|
||||
if data.Username != "" {
|
||||
var user model.User
|
||||
err := h.DB.Where("username", data.Username).First(&user).Error
|
||||
if err == nil {
|
||||
session = session.Where("user_id", user.Id)
|
||||
}
|
||||
}
|
||||
if data.Prompt != "" {
|
||||
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
|
||||
}
|
||||
if len(data.CreatedAt) == 2 {
|
||||
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
|
||||
}
|
||||
var total int64
|
||||
session.Model(&model.VideoJob{}).Count(&total)
|
||||
var list []model.VideoJob
|
||||
var items = make([]vo.VideoJob, 0)
|
||||
offset := (data.Page - 1) * data.PageSize
|
||||
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
|
||||
if err == nil {
|
||||
// 填充数据
|
||||
for _, item := range list {
|
||||
var job vo.VideoJob
|
||||
err = utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
job.CreatedAt = item.CreatedAt.Unix()
|
||||
if job.VideoURL == "" {
|
||||
job.VideoURL = job.WaterURL
|
||||
}
|
||||
items = append(items, job)
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
|
||||
}
|
||||
|
||||
func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
tab := c.Query("tab")
|
||||
|
||||
tx := h.DB.Begin()
|
||||
var md, remark, fileURL string
|
||||
var power, userId, progress int
|
||||
switch tab {
|
||||
case "suno":
|
||||
var job model.SunoJob
|
||||
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
tx.Delete(&job)
|
||||
md = "suno"
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.AudioURL
|
||||
break
|
||||
case "luma":
|
||||
var job model.VideoJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
resp.ERROR(c, "记录不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除任务
|
||||
tx.Delete(&job)
|
||||
md = job.Type
|
||||
power = job.Power
|
||||
userId = job.UserId
|
||||
remark = fmt.Sprintf("LUMA 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.VideoURL
|
||||
if fileURL == "" {
|
||||
fileURL = job.WaterURL
|
||||
}
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if progress != 100 {
|
||||
err := h.userService.IncreasePower(userId, power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: md,
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
// remove image
|
||||
err := h.uploader.GetUploadHandler().Delete(fileURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
@@ -146,19 +146,15 @@ func (h *RedeemHandler) Set(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if data.Id > 0 {
|
||||
err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -180,12 +180,7 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
|
||||
|
||||
// 删除任务
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tx.Delete(&job)
|
||||
// 如果任务未完成,或者任务失败,则恢复用户算力
|
||||
if job.Progress != 100 {
|
||||
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{
|
||||
|
||||
@@ -403,12 +403,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
|
||||
// remove job recode
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tx.Delete(&job)
|
||||
// 如果任务未完成,或者任务失败,则恢复用户算力
|
||||
if job.Progress != 100 {
|
||||
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"github.com/shopspring/decimal"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -105,7 +104,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
|
||||
amount := product.Discount
|
||||
var payURL, returnURL, notifyURL string
|
||||
switch data.PayWay {
|
||||
case "alipay":
|
||||
|
||||
128
api/handler/realtime_handler.go
Normal file
128
api/handler/realtime_handler.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// OpenAI Realtime API Relay Server
|
||||
|
||||
type RealtimeHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
|
||||
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
|
||||
}
|
||||
|
||||
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
// 获取客户端请求中指定的子协议
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
md := c.Query("model")
|
||||
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将 HTTP 协议升级为 Websocket 协议
|
||||
subProtocols := strings.Split(clientProtocols, ",")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: subProtocols,
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// 目前只针对 VIP 用户可以访问
|
||||
if !user.Vip {
|
||||
sendError(ws, "当前功能只针对 VIP 用户开放")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
|
||||
if apiKey.Id == 0 {
|
||||
sendError(ws, "管理员未配置 Realtime API KEY")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
|
||||
// 连接到真实的后端服务器,传入相同的子协议
|
||||
headers := http.Header{}
|
||||
// 修正子协议内容
|
||||
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
|
||||
if clientProtocols != "" {
|
||||
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
|
||||
}
|
||||
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
|
||||
if err != nil {
|
||||
sendError(ws, "桥接后端 API 失败:"+err.Error())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// 确保协议一致性,如果失败返回
|
||||
if ws.Subprotocol() != backendConn.Subprotocol() {
|
||||
sendError(ws, "Websocket 子协议不匹配")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 更新API KEY 最后使用时间
|
||||
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 开始双向转发
|
||||
errorChan := make(chan error, 2)
|
||||
go relay(ws, backendConn, errorChan)
|
||||
go relay(backendConn, ws, errorChan)
|
||||
|
||||
// 等待其中一个连接关闭
|
||||
err = <-errorChan
|
||||
logger.Infof("Relay ended: %v", err)
|
||||
}
|
||||
|
||||
func relay(src, dst *websocket.Conn, errorChan chan error) {
|
||||
for {
|
||||
messageType, message, err := src.ReadMessage()
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
return
|
||||
}
|
||||
err = dst.WriteMessage(messageType, message)
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendError(ws *websocket.Conn, message string) {
|
||||
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -250,18 +250,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||
|
||||
// 删除任务
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Delete(&job).Error; err != nil {
|
||||
tx.Rollback()
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
tx.Delete(&job)
|
||||
// 如果任务未完成,或者任务失败,则恢复用户算力
|
||||
if job.Progress != 100 {
|
||||
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "stable-diffusion",
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s, Err: %s", job.TaskId, job.ErrMsg),
|
||||
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg),
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
@@ -156,6 +156,9 @@ func (h *VideoHandler) List(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
item.CreatedAt = v.CreatedAt.Unix()
|
||||
if item.VideoURL == "" {
|
||||
item.VideoURL = v.WaterURL
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Websocket 连接处理 handler
|
||||
@@ -37,7 +38,11 @@ func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *g
|
||||
}
|
||||
|
||||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: strings.Split(clientProtocols, ","),
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
|
||||
Reference in New Issue
Block a user