feat: add authorization for MidJourney function calls

This commit is contained in:
RockYang
2023-08-16 23:16:44 +08:00
parent c8998ba294
commit fab43097dc
12 changed files with 81 additions and 58 deletions

View File

@@ -73,6 +73,7 @@ func (h *UserHandler) Save(c *gin.Context) {
Mobile string `json:"mobile"`
Nickname string `json:"nickname"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
@@ -91,6 +92,7 @@ func (h *UserHandler) Save(c *gin.Context) {
"nickname": data.Nickname,
"mobile": data.Mobile,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"expired_time": utils.Str2stamp(data.ExpiredTime),

View File

@@ -326,44 +326,53 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} // end for
if functionCall { // 调用函数完成任务
logger.Info("函数名称:", functionName)
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Info("函数参数:", params)
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
//logger.Info(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
}
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
// for creating image, check if the user's img_calls > 0
if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, "![](/images/wx.png)")
} else {
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
logger.Debug(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
// update user's img_calls
h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
}
}
@@ -371,10 +380,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKey == "" { // 如果用户使用的是自己绑定的 API KEY 则不扣减对话次数
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
if res.Error != nil {
return res.Error
}
h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {

View File

@@ -109,6 +109,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
job.UserId = task.UserId
job.ChatId = task.ChatId
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Content = data.Content
job.Prompt = data.Prompt
job.Image = utils.JsonEncode(data.Image)

View File

@@ -108,7 +108,8 @@ func (h *UserHandler) Register(c *gin.Context) {
Model: h.App.ChatConfig.Model,
ApiKey: "",
}),
Calls: h.App.SysConfig.UserInitCalls,
Calls: h.App.SysConfig.UserInitCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
}
res = h.db.Create(&user)
if res.Error != nil {