fix conflicts

This commit is contained in:
RockYang
2024-03-12 18:03:24 +08:00
239 changed files with 602 additions and 452 deletions

View File

@@ -48,7 +48,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
Enabled: data.Enabled,
SortNum: data.SortNum,
Open: data.Open,
Weight: data.Weight}
Power: data.Weight}
item.Id = data.Id
if item.Id > 0 {
item.CreatedAt = time.Unix(data.CreatedAt, 0)

View File

@@ -32,8 +32,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Discount float64 `json:"discount"`
Enabled bool `json:"enabled"`
Days int `json:"days"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
Power int `json:"power"`
CreatedAt int64 `json:"created_at"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -46,8 +45,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
Price: data.Price,
Discount: data.Discount,
Days: data.Days,
Calls: data.Calls,
ImgCalls: data.ImgCalls,
Power: data.Power,
Enabled: data.Enabled}
item.Id = data.Id
if item.Id > 0 {

View File

@@ -66,8 +66,6 @@ func (h *UserHandler) Save(c *gin.Context) {
Id uint `json:"id"`
Password string `json:"password"`
Username string `json:"username"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"`
ChatModels []string `json:"chat_models"`
ExpiredTime string `json:"expired_time"`
@@ -86,8 +84,6 @@ func (h *UserHandler) Save(c *gin.Context) {
// 此处需要用 map 更新,用结构体无法更新 0 值
res = h.db.Model(&user).Updates(map[string]interface{}{
"username": data.Username,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"vip": data.Vip,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
@@ -113,8 +109,6 @@ func (h *UserHandler) Save(c *gin.Context) {
types.ChatGLM: "",
},
}),
Calls: data.Calls,
ImgCalls: data.ImgCalls,
}
res = h.db.Create(&u)
_ = utils.CopyObject(u, &userVo)

View File

@@ -19,7 +19,7 @@ import (
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -103,8 +103,6 @@ func (h *ChatHandler) sendAzureMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
@@ -145,8 +143,8 @@ func (h *ChatHandler) sendAzureMessage(
}
// 计算本次对话消耗的总 token 数量
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens += getTotalTokens(req)
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
@@ -155,7 +153,7 @@ func (h *ChatHandler) sendAzureMessage(
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
Tokens: replyTokens,
UseContext: true,
Model: req.Model,
}
@@ -166,8 +164,8 @@ func (h *ChatHandler) sendAzureMessage(
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -36,7 +36,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendBaiduMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -171,8 +168,8 @@ func (h *ChatHandler) sendBaiduMessage(
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
@@ -190,8 +187,8 @@ func (h *ChatHandler) sendBaiduMessage(
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -111,7 +111,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.Model = types.ChatModel{
Id: chatModel.Id,
Value: chatModel.Value,
Weight: chatModel.Weight,
Power: chatModel.Power,
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
@@ -207,13 +207,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return nil
}
if userVo.Calls < session.Model.Weight {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Calls, session.Model.Weight))
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d已不足以支付当前模型的单次对话需要消耗的对话额度%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
}
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
if userVo.Power <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
utils.ReplyMessage(ws, ErrImg)
return nil
@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg)
return nil
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > types.GetModelMaxToken(session.Model.Value) {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
@@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"required": required,
},
})
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
}
//if len(tools) > 0 {
// req.Tools = tools
// req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.XunFei:
@@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 加载聊天上下文
var chatCtx []interface{}
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num
tokens := req.MaxTokens
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
_ = utils.JsonDecode(role.Context, &messages)
if chatConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
@@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
}
}
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
@@ -533,23 +533,28 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
return client.Do(request)
}
// 扣减用户的对话次数
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
// 仅当用户没有导入自己的 API KEY 时才进行扣减
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
num := 1
if session.Model.Weight > 0 {
num = session.Model.Weight
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
h.db.Debug().Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,
Amount: power,
Mark: types.PowerSub,
Balance: userVo.Power - power,
Model: session.Model.Value,
Remark: fmt.Sprintf("提问长度:%d回复长度%d", promptTokens, replyTokens),
CreatedAt: time.Now(),
})
}
}
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
h.db.Model(&model.User{}).Where("id = ?", userId).
UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
}
// 将AI回复消息中生成的图片链接下载到本地

View File

@@ -20,7 +20,7 @@ import (
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -107,9 +107,6 @@ func (h *ChatHandler) sendChatGLMMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -150,8 +147,8 @@ func (h *ChatHandler) sendChatGLMMessage(
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
@@ -169,8 +166,9 @@ func (h *ChatHandler) sendChatGLMMessage(
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -20,7 +20,7 @@ import (
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
if response.Body != nil {
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err
} else {
defer response.Body.Close()
@@ -171,9 +173,6 @@ func (h *ChatHandler) sendOpenAiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -218,16 +217,16 @@ func (h *ChatHandler) sendOpenAiMessage(
}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
var replyTokens = 0
if toolCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(function.Name, req.Model)
totalTokens += tokens
replyTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
replyTokens += tokens
} else {
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
totalTokens += getTotalTokens(req)
replyTokens += getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
@@ -236,7 +235,7 @@ func (h *ChatHandler) sendOpenAiMessage(
Type: types.ReplyMsg,
Icon: role.Icon,
Content: h.extractImgUrl(message.Content),
Tokens: totalTokens,
Tokens: replyTokens,
UseContext: useContext,
Model: req.Model,
}
@@ -247,8 +246,8 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -31,7 +31,7 @@ type qWenResp struct {
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -128,9 +128,6 @@ func (h *ChatHandler) sendQWenMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -171,8 +168,8 @@ func (h *ChatHandler) sendQWenMessage(
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
@@ -190,8 +187,9 @@ func (h *ChatHandler) sendQWenMessage(
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -58,7 +58,7 @@ var Model2URL = map[string]string{
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -166,9 +166,6 @@ func (h *ChatHandler) sendXunFeiMessage(
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
h.subUserCalls(userVo, session)
if message.Role == "" {
message.Role = "assistant"
}
@@ -209,8 +206,8 @@ func (h *ChatHandler) sendXunFeiMessage(
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
@@ -228,8 +225,9 @@ func (h *ChatHandler) sendXunFeiMessage(
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.incUserTokenFee(userVo.Id, totalTokens)
// 更新用户算力
h.subUserPower(userVo, session, promptToken, replyTokens)
}
// 保存当前会话

View File

@@ -192,7 +192,6 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
}
logger.Debugf("绘画参数:%+v", params)
// check img calls
var user model.User
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
@@ -200,8 +199,8 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
return
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "当前用户的绘图次数额度不足")
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return
}
@@ -274,8 +273,22 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
}
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// update user's img_calls
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// 更新用户算力
tx = h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
h.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: user.Power - h.App.SysConfig.DallPower,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: "",
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content)
}

View File

@@ -48,8 +48,8 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.MjPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -160,13 +160,18 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
opt := "绘图"
if data.TaskType == types.TaskBlend.String() {
data.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
opt = "融图"
} else if data.TaskType == types.TaskSwapFace.String() {
data.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
opt = "换脸"
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
@@ -187,8 +192,23 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("%s操作任务ID%s", opt, job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -276,6 +296,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
TaskId: taskId,
Progress: 0,
Prompt: data.Prompt,
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil || res.RowsAffected == 0 {
@@ -300,8 +321,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "mid-journey",
Remark: fmt.Sprintf("Variation 操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -368,13 +404,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
continue
}
// 失败的任务直接删除
if job.Progress == -1 {
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
jobs = append(jobs, job)
continue
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
// discord 服务器图片需要使用代理转发图片数据流
if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {

View File

@@ -202,8 +202,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
// 创建订单
remark := types.OrderRemark{
Days: product.Days,
Calls: product.Calls,
ImgCalls: product.ImgCalls,
Power: product.Power,
Name: product.Name,
Price: product.Price,
Discount: product.Discount,
@@ -313,20 +312,16 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
if remark.Days > 0 { // 只延期 VIP不增加调用次数
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else { // 充值点卡,直接增加次数即可
user.Calls += remark.Calls
user.ImgCalls += remark.ImgCalls
user.Power += remark.Power
}
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, calls == 0
if remark.Days > 0 { // vip 套餐days > 0, power == 0
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Calls += h.App.SysConfig.VipMonthCalls
user.ImgCalls += h.App.SysConfig.VipMonthImgCalls
user.Power += h.App.SysConfig.VipMonthPower
user.Vip = true
} else { //点卡days == 0, calls > 0
user.Calls += remark.Calls
user.ImgCalls += remark.ImgCalls
user.Power += remark.Power
}
}

View File

@@ -7,11 +7,13 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"math"
"strings"
"sync"
"time"
)
type RewardHandler struct {
@@ -30,7 +32,6 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
func (h *RewardHandler) Verify(c *gin.Context) {
var data struct {
TxId string `json:"tx_id"`
Type string `json:"type"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -63,16 +64,11 @@ func (h *RewardHandler) Verify(c *gin.Context) {
tx := h.db.Begin()
exchange := vo.RewardExchange{}
if data.Type == "chat" {
calls := math.Ceil(item.Amount / h.App.SysConfig.ChatCallPrice)
exchange.Calls = int(calls)
res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
} else if data.Type == "img" {
calls := math.Ceil(item.Amount / h.App.SysConfig.ImgCallPrice)
exchange.ImgCalls = int(calls)
res = h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", calls))
}
power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
exchange.Power = int(power)
res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
@@ -81,13 +77,25 @@ func (h *RewardHandler) Verify(c *gin.Context) {
item.Status = true
item.UserId = user.Id
item.Exchange = utils.JsonEncode(exchange)
res = h.db.Updates(&item)
res = tx.Updates(&item)
if res.Error != nil {
tx.Rollback()
resp.ERROR(c, "更新数据库失败!")
return
}
// 记录算力充值日志
h.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerReward,
Amount: exchange.Power,
Balance: user.Power + exchange.Power,
Mark: types.PowerAdd,
Model: "",
Remark: fmt.Sprintf("众筹充值算力,金额:%f价格%f", item.Amount, h.App.SysConfig.PowerPrice),
CreatedAt: time.Now(),
})
tx.Commit()
resp.SUCCESS(c)

View File

@@ -72,8 +72,8 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值")
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画")
return false
}
@@ -140,6 +140,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Power: h.App.SysConfig.SdPower,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
@@ -162,8 +163,23 @@ func (h *SdJobHandler) Image(c *gin.Context) {
_ = client.Send([]byte("Task Updated"))
}
// update user's img calls
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// update user's power
tx := h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := utils.GetLoginUser(c, h.db)
h.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: "stable-diffusion",
Remark: fmt.Sprintf("绘图操作任务ID%s", job.TaskId),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c)
}
@@ -232,18 +248,7 @@ func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int,
continue
}
if job.Progress == -1 {
h.db.Delete(&model.SdJob{Id: job.Id})
}
if item.Progress < 100 {
// 5 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*5 {
h.db.Delete(&item)
// 退回绘图次数
h.db.Model(&model.User{}).Where("id = ?", item.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
continue
}
// 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, "")
if err == nil {

View File

@@ -102,8 +102,7 @@ func (h *UserHandler) Register(c *gin.Context) {
types.ChatGLM: "",
},
}),
Calls: h.App.SysConfig.InitChatCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
Power: h.App.SysConfig.InitPower,
}
res = h.db.Create(&user)
@@ -117,11 +116,8 @@ func (h *UserHandler) Register(c *gin.Context) {
if data.InviteCode != "" {
// 增加邀请数量
h.db.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
if h.App.SysConfig.InviteChatCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("calls", gorm.Expr("calls + ?", h.App.SysConfig.InviteChatCalls))
}
if h.App.SysConfig.InviteImgCalls > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", h.App.SysConfig.InviteImgCalls))
if h.App.SysConfig.InvitePower > 0 {
h.db.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
}
// 添加邀请记录
@@ -130,7 +126,7 @@ func (h *UserHandler) Register(c *gin.Context) {
UserId: user.Id,
Username: user.Username,
InviteCode: inviteCode.Code,
Reward: utils.JsonEncode(types.InviteReward{ChatCalls: h.App.SysConfig.InviteChatCalls, ImgCalls: h.App.SysConfig.InviteImgCalls}),
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
})
}
@@ -254,10 +250,7 @@ type userProfile struct {
Username string `json:"username"`
Avatar string `json:"avatar"`
ChatConfig types.UserChatConfig `json:"chat_config"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
TotalTokens int64 `json:"total_tokens"`
Tokens int `json:"tokens"`
Power int `json:"power"`
ExpiredTime int64 `json:"expired_time"`
Vip bool `json:"vip"`
}