add 'type' field for ChatModel, support Chat and Image model

This commit is contained in:
RockYang
2024-12-25 18:57:18 +08:00
parent cbf06eea24
commit acee2d9d81
35 changed files with 766 additions and 698 deletions

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
@@ -17,14 +18,13 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
dallService *dalle.Service
uploader *oss.UploaderManager
userService *service.UserService
@@ -42,49 +42,49 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
}
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
var chatModel model.ChatModel
if res := h.DB.Where("id = ?", data.ModelId).First(&chatModel); res.Error != nil {
resp.ERROR(c, "模型不存在")
return
}
// 检查用户剩余算力
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
if user.Power < chatModel.Power {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
task := types.DallTask{
ClientId: data.ClientId,
UserId: uint(userId),
ModelId: chatModel.Id,
ModelName: chatModel.Value,
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: h.App.SysConfig.DallPower,
TranslateModelId: h.App.SysConfig.TranslateModelId,
Power: chatModel.Power,
}
job := model.DallJob{
UserId: uint(userId),
Prompt: data.Prompt,
Power: task.Power,
Power: chatModel.Power,
TaskInfo: utils.JsonEncode(task),
}
res := h.DB.Create(&job)
@@ -95,6 +95,17 @@ func (h *DallJobHandler) Image(c *gin.Context) {
task.Id = job.Id
h.dallService.PushTask(task)
// 扣减算力
err = h.userService.DecreasePower(int(user.Id), chatModel.Power, model.PowerLog{
Type: types.PowerConsume,
Model: chatModel.Value,
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
})
if err != nil {
resp.ERROR(c, "error with decrease power: "+err.Error())
return
}
resp.SUCCESS(c)
}
@@ -210,3 +221,25 @@ func (h *DallJobHandler) Publish(c *gin.Context) {
resp.SUCCESS(c)
}
func (h *DallJobHandler) GetModels(c *gin.Context) {
var models []model.ChatModel
err := h.DB.Where("type", "img").Where("enabled", true).Find(&models).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
var modelVos []vo.ChatModel
for _, v := range models {
var modelVo vo.ChatModel
err := utils.CopyObject(v, &modelVo)
if err != nil {
continue
}
modelVo.Id = v.Id
modelVos = append(modelVos, modelVo)
}
resp.SUCCESS(c, modelVos)
}