merge v4.1.6

This commit is contained in:
RockYang
2025-03-05 18:42:30 +08:00
71 changed files with 5678 additions and 258 deletions

View File

@@ -143,65 +143,67 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
// FixData 修复数据
func (h *ConfigHandler) FixData(c *gin.Context) {
var fixed bool
version := "data_fix_4.1.4"
err := h.levelDB.Get(version, &fixed)
if err == nil || fixed {
resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
return
}
tx := h.DB.Begin()
var users []model.User
err = tx.Find(&users).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, user := range users {
if user.Email != "" || user.Mobile != "" {
continue
}
if utils.IsValidEmail(user.Username) {
user.Email = user.Username
} else if utils.IsValidMobile(user.Username) {
user.Mobile = user.Username
}
err = tx.Save(&user).Error
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
}
var orders []model.Order
err = h.DB.Find(&orders).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
for _, order := range orders {
if order.PayWay == "支付宝" {
order.PayWay = "alipay"
order.PayType = "alipay"
} else if order.PayWay == "微信支付" {
order.PayWay = "wechat"
order.PayType = "wxpay"
} else if order.PayWay == "hupi" {
order.PayType = "wxpay"
}
err = tx.Save(&order).Error
if err != nil {
resp.ERROR(c, err.Error())
tx.Rollback()
return
}
}
tx.Commit()
err = h.levelDB.Put(version, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
resp.ERROR(c, "当前升级版本没有数据需要修正!")
return
//var fixed bool
//version := "data_fix_4.1.4"
//err := h.levelDB.Get(version, &fixed)
//if err == nil || fixed {
// resp.ERROR(c, "当前版本数据修复已完成,请不要重复执行操作")
// return
//}
//tx := h.DB.Begin()
//var users []model.User
//err = tx.Find(&users).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, user := range users {
// if user.Email != "" || user.Mobile != "" {
// continue
// }
// if utils.IsValidEmail(user.Username) {
// user.Email = user.Username
// } else if utils.IsValidMobile(user.Username) {
// user.Mobile = user.Username
// }
// err = tx.Save(&user).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//
//var orders []model.Order
//err = h.DB.Find(&orders).Error
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//for _, order := range orders {
// if order.PayWay == "支付宝" {
// order.PayWay = "alipay"
// order.PayType = "alipay"
// } else if order.PayWay == "微信支付" {
// order.PayWay = "wechat"
// order.PayType = "wxpay"
// } else if order.PayWay == "hupi" {
// order.PayType = "wxpay"
// }
// err = tx.Save(&order).Error
// if err != nil {
// resp.ERROR(c, err.Error())
// tx.Rollback()
// return
// }
//}
//tx.Commit()
//err = h.levelDB.Put(version, true)
//if err != nil {
// resp.ERROR(c, err.Error())
// return
//}
//resp.SUCCESS(c)
}

View File

@@ -0,0 +1,254 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type ImageHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewImageHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *ImageHandler {
return &ImageHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type imageQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MjList Midjourney 任务列表
func (h *ImageHandler) MjList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.MidJourneyJob{}).Count(&total)
var list []model.MidJourneyJob
var items = make([]vo.MidJourneyJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.MidJourneyJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// SdList Stable Diffusion 任务列表
func (h *ImageHandler) SdList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SdJob{}).Count(&total)
var list []model.SdJob
var items = make([]vo.SdJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SdJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// DallList DALL-E 任务列表
func (h *ImageHandler) DallList(c *gin.Context) {
var data imageQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.DallJob{}).Count(&total)
var list []model.DallJob
var items = make([]vo.DallJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.DallJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *ImageHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, imgURL string
var power, userId, progress int
switch tab {
case "mj":
var job model.MidJourneyJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "mid-journey"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "sd":
var job model.SdJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "stable-diffusion"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
case "dall":
var job model.DallJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = "dall-e-3"
power = job.Power
userId = int(job.UserId)
remark = fmt.Sprintf("任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
imgURL = job.ImgURL
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(imgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -0,0 +1,200 @@
package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type MediaHandler struct {
handler.BaseHandler
userService *service.UserService
uploader *oss.UploaderManager
}
func NewMediaHandler(app *core.AppServer, db *gorm.DB, userService *service.UserService, manager *oss.UploaderManager) *MediaHandler {
return &MediaHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, userService: userService, uploader: manager}
}
type mediaQuery struct {
Prompt string `json:"prompt"`
Username string `json:"username"`
CreatedAt []string `json:"created_at"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// SunoList Suno 任务列表
func (h *MediaHandler) SunoList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.SunoJob{}).Count(&total)
var list []model.SunoJob
var items = make([]vo.SunoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.SunoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
// LumaList Luma 视频任务列表
func (h *MediaHandler) LumaList(c *gin.Context) {
var data mediaQuery
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
session := h.DB.Session(&gorm.Session{})
if data.Username != "" {
var user model.User
err := h.DB.Where("username", data.Username).First(&user).Error
if err == nil {
session = session.Where("user_id", user.Id)
}
}
if data.Prompt != "" {
session = session.Where("prompt LIKE ?", "%"+data.Prompt+"%")
}
if len(data.CreatedAt) == 2 {
session = session.Where("created_at >= ? AND created_at <= ?", data.CreatedAt[0], data.CreatedAt[1])
}
var total int64
session.Model(&model.VideoJob{}).Count(&total)
var list []model.VideoJob
var items = make([]vo.VideoJob, 0)
offset := (data.Page - 1) * data.PageSize
err := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&list).Error
if err == nil {
// 填充数据
for _, item := range list {
var job vo.VideoJob
err = utils.CopyObject(item, &job)
if err != nil {
continue
}
job.CreatedAt = item.CreatedAt.Unix()
if job.VideoURL == "" {
job.VideoURL = job.WaterURL
}
items = append(items, job)
}
}
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, items))
}
func (h *MediaHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
tab := c.Query("tab")
tx := h.DB.Begin()
var md, remark, fileURL string
var power, userId, progress int
switch tab {
case "suno":
var job model.SunoJob
if err := h.DB.Where("id", id).First(&job).Error; err != nil {
resp.ERROR(c, "记录不存在")
return
}
tx.Delete(&job)
md = "suno"
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("SUNO 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.AudioURL
break
case "luma":
var job model.VideoJob
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
resp.ERROR(c, "记录不存在")
return
}
// 删除任务
tx.Delete(&job)
md = job.Type
power = job.Power
userId = job.UserId
remark = fmt.Sprintf("LUMA 任务失败退回算力。任务ID%dErr: %s", job.Id, job.ErrMsg)
progress = job.Progress
fileURL = job.VideoURL
if fileURL == "" {
fileURL = job.WaterURL
}
break
default:
resp.ERROR(c, types.InvalidArgs)
return
}
if progress != 100 {
err := h.userService.IncreasePower(userId, power, model.PowerLog{
Type: types.PowerRefund,
Model: md,
Remark: remark,
})
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(fileURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}

View File

@@ -146,19 +146,15 @@ func (h *RedeemHandler) Set(c *gin.Context) {
}
func (h *RedeemHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -180,12 +180,7 @@ func (h *DallJobHandler) Remove(c *gin.Context) {
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(int(job.UserId), job.Power, model.PowerLog{

View File

@@ -403,12 +403,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
// remove job recode
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{

View File

@@ -17,7 +17,6 @@ import (
"geekai/store/model"
"geekai/utils"
"geekai/utils/resp"
"github.com/shopspring/decimal"
"net/http"
"sync"
"time"
@@ -105,7 +104,7 @@ func (h *PaymentHandler) Pay(c *gin.Context) {
return
}
amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
amount := product.Discount
var payURL, returnURL, notifyURL string
switch data.PayWay {
case "alipay":

View File

@@ -0,0 +1,128 @@
package handler
import (
"fmt"
"geekai/core"
"geekai/store/model"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// OpenAI Realtime API Relay Server
type RealtimeHandler struct {
BaseHandler
}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
md := c.Query("model")
userId := h.GetLoginUserId(c)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
c.Abort()
return
}
// 将 HTTP 协议升级为 Websocket 协议
subProtocols := strings.Split(clientProtocols, ",")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: subProtocols,
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
defer ws.Close()
// 目前只针对 VIP 用户可以访问
if !user.Vip {
sendError(ws, "当前功能只针对 VIP 用户开放")
c.Abort()
return
}
var apiKey model.ApiKey
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
if apiKey.Id == 0 {
sendError(ws, "管理员未配置 Realtime API KEY")
c.Abort()
return
}
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
// 连接到真实的后端服务器,传入相同的子协议
headers := http.Header{}
// 修正子协议内容
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
if clientProtocols != "" {
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
}
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
if err != nil {
sendError(ws, "桥接后端 API 失败:"+err.Error())
c.Abort()
return
}
defer backendConn.Close()
// 确保协议一致性,如果失败返回
if ws.Subprotocol() != backendConn.Subprotocol() {
sendError(ws, "Websocket 子协议不匹配")
c.Abort()
return
}
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
// 开始双向转发
errorChan := make(chan error, 2)
go relay(ws, backendConn, errorChan)
go relay(backendConn, ws, errorChan)
// 等待其中一个连接关闭
err = <-errorChan
logger.Infof("Relay ended: %v", err)
}
func relay(src, dst *websocket.Conn, errorChan chan error) {
for {
messageType, message, err := src.ReadMessage()
if err != nil {
errorChan <- err
return
}
err = dst.WriteMessage(messageType, message)
if err != nil {
errorChan <- err
return
}
}
}
func sendError(ws *websocket.Conn, message string) {
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
if err != nil {
logger.Error(err)
}
}

View File

@@ -250,18 +250,13 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
// 删除任务
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Delete(&job)
// 如果任务未完成,或者任务失败,则恢复用户算力
if job.Progress != 100 {
err := h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%s Err: %s", job.TaskId, job.ErrMsg),
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d Err: %s", job.Id, job.ErrMsg),
})
if err != nil {
tx.Rollback()

View File

@@ -156,6 +156,9 @@ func (h *VideoHandler) List(c *gin.Context) {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if item.VideoURL == "" {
item.VideoURL = v.WaterURL
}
items = append(items, item)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
)
// Websocket 连接处理 handler
@@ -37,7 +38,11 @@ func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *g
}
func (h *WebsocketHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: strings.Split(clientProtocols, ","),
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()