feat: add authorization for MidJourney function calls

This commit is contained in:
RockYang 2023-08-16 23:16:44 +08:00
parent c8998ba294
commit fab43097dc
12 changed files with 81 additions and 58 deletions

View File

@ -93,6 +93,7 @@ type SystemConfig struct {
AdminTitle string `json:"admin_title"` AdminTitle string `json:"admin_title"`
Models []string `json:"models"` Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用 UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
EnabledRegister bool `json:"enabled_register"` EnabledRegister bool `json:"enabled_register"`
EnabledMsgService bool `json:"enabled_msg_service"` EnabledMsgService bool `json:"enabled_msg_service"`
} }

View File

@ -73,6 +73,7 @@ func (h *UserHandler) Save(c *gin.Context) {
Mobile string `json:"mobile"` Mobile string `json:"mobile"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
Calls int `json:"calls"` Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"` ChatRoles []string `json:"chat_roles"`
ExpiredTime string `json:"expired_time"` ExpiredTime string `json:"expired_time"`
Status bool `json:"status"` Status bool `json:"status"`
@ -91,6 +92,7 @@ func (h *UserHandler) Save(c *gin.Context) {
"nickname": data.Nickname, "nickname": data.Nickname,
"mobile": data.Mobile, "mobile": data.Mobile,
"calls": data.Calls, "calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status, "status": data.Status,
"chat_roles_json": utils.JsonEncode(data.ChatRoles), "chat_roles_json": utils.JsonEncode(data.ChatRoles),
"expired_time": utils.Str2stamp(data.ExpiredTime), "expired_time": utils.Str2stamp(data.ExpiredTime),

View File

@ -326,44 +326,53 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} // end for } // end for
if functionCall { // 调用函数完成任务 if functionCall { // 调用函数完成任务
logger.Info("函数名称:", functionName)
var params map[string]interface{} var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params) _ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Info("函数参数:", params) logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
//logger.Info(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
}
utils.ReplyChunkMessage(ws, types.WsMessage{ // for creating image, check if the user's img_calls > 0
Type: types.WsMiddle, if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
Content: content, utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
}) utils.ReplyMessage(ws, "![](/images/wx.png)")
contents = append(contents, content) } else {
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
logger.Debug(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
// update user's img_calls
h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
} }
} }
@ -371,10 +380,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if len(contents) > 0 { if len(contents) > 0 {
// 更新用户的对话次数 // 更新用户的对话次数
if userVo.ChatConfig.ApiKey == "" { // 如果用户使用的是自己绑定的 API KEY 则不扣减对话次数 if userVo.ChatConfig.ApiKey == "" { // 如果用户使用的是自己绑定的 API KEY 则不扣减对话次数
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
if res.Error != nil {
return res.Error
}
} }
if message.Role == "" { if message.Role == "" {

View File

@ -109,6 +109,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
job.UserId = task.UserId job.UserId = task.UserId
job.ChatId = task.ChatId job.ChatId = task.ChatId
job.MessageId = data.MessageId job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Content = data.Content job.Content = data.Content
job.Prompt = data.Prompt job.Prompt = data.Prompt
job.Image = utils.JsonEncode(data.Image) job.Image = utils.JsonEncode(data.Image)

View File

@ -108,7 +108,8 @@ func (h *UserHandler) Register(c *gin.Context) {
Model: h.App.ChatConfig.Model, Model: h.App.ChatConfig.Model,
ApiKey: "", ApiKey: "",
}), }),
Calls: h.App.SysConfig.UserInitCalls, Calls: h.App.SysConfig.UserInitCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
} }
res = h.db.Create(&user) res = h.db.Create(&user)
if res.Error != nil { if res.Error != nil {

View File

@ -3,15 +3,16 @@ package model
import "time" import "time"
type MidJourneyJob struct { type MidJourneyJob struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
UserId uint UserId uint
ChatId string ChatId string
MessageId string MessageId string
Hash string ReferenceId string
Content string Hash string
Prompt string Content string
Image string Prompt string
CreatedAt time.Time Image string
CreatedAt time.Time
} }
func (MidJourneyJob) TableName() string { func (MidJourneyJob) TableName() string {

View File

@ -10,6 +10,7 @@ type User struct {
Salt string // 密码盐 Salt string // 密码盐
Tokens int64 // 剩余tokens Tokens int64 // 剩余tokens
Calls int // 剩余对话次数 Calls int // 剩余对话次数
ImgCalls int // 剩余绘图次数
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色 ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
ExpiredTime int64 // 账户到期时间 ExpiredTime int64 // 账户到期时间

View File

@ -8,9 +8,10 @@ type User struct {
Mobile string `json:"mobile"` Mobile string `json:"mobile"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
Avatar string `json:"avatar"` Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐 Salt string `json:"salt"` // 密码盐
Tokens int64 `json:"tokens"` // 剩余tokens Tokens int64 `json:"tokens"` // 剩余tokens
Calls int `json:"calls"` // 剩余对话次数 Calls int `json:"calls"` // 剩余对话次数
ImgCalls int `json:"img_calls"`
ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置 ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置
ChatRoles []string `json:"chat_roles"` // 聊天角色集合 ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ExpiredTime int64 `json:"expired_time"` // 账户到期时间 ExpiredTime int64 `json:"expired_time"` // 账户到期时间

View File

@ -31,4 +31,7 @@ ALTER TABLE `chatgpt_mj_jobs`
-- --
ALTER TABLE `chatgpt_mj_jobs` ALTER TABLE `chatgpt_mj_jobs`
MODIFY `id` int NOT NULL AUTO_INCREMENT; MODIFY `id` int NOT NULL AUTO_INCREMENT;
COMMIT;
ALTER TABLE `chatgpt_mj_jobs` ADD `reference_id` CHAR(40) NULL DEFAULT NULL COMMENT '引用消息 ID' AFTER `message_id`;
ALTER TABLE `chatgpt_users` ADD `img_calls` INT NOT NULL DEFAULT '0' COMMENT '剩余绘图次数' AFTER `calls`;

View File

@ -6,6 +6,7 @@
<el-table-column prop="username" label="用户名"/> <el-table-column prop="username" label="用户名"/>
<el-table-column prop="tx_id" label="转账单号"/> <el-table-column prop="tx_id" label="转账单号"/>
<el-table-column prop="amount" label="转账金额"/> <el-table-column prop="amount" label="转账金额"/>
<el-table-column prop="remark" label="备注"/>
<el-table-column label="转账时间"> <el-table-column label="转账时间">
<template #default="scope"> <template #default="scope">
@ -27,11 +28,10 @@
</template> </template>
<script setup> <script setup>
import {reactive, ref} from "vue"; import {ref} from "vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet} from "@/utils/http";
import {ElMessage} from "element-plus"; import {ElMessage} from "element-plus";
import {dateFormat, disabledDate, removeArrayItem} from "@/utils/libs"; import {dateFormat} from "@/utils/libs";
import {Plus} from "@element-plus/icons-vue";
// //
const items = ref([]) const items = ref([])

View File

@ -9,9 +9,12 @@
<el-form-item label="控制台标题" prop="admin_title"> <el-form-item label="控制台标题" prop="admin_title">
<el-input v-model="system['admin_title']"/> <el-input v-model="system['admin_title']"/>
</el-form-item> </el-form-item>
<el-form-item label="注册赠送次数" prop="init_calls"> <el-form-item label="赠送对话次数" prop="init_calls">
<el-input v-model.number="system['user_init_calls']" placeholder="新用户注册赠送对话次数"/> <el-input v-model.number="system['user_init_calls']" placeholder="新用户注册赠送对话次数"/>
</el-form-item> </el-form-item>
<el-form-item label="赠送绘图次数" prop="init_calls">
<el-input v-model.number="system['init_img_calls']" placeholder="新用户注册赠送绘图次数"/>
</el-form-item>
<el-form-item label="短信验证服务" prop="enabled_msg_service"> <el-form-item label="短信验证服务" prop="enabled_msg_service">
<el-switch v-model="system['enabled_msg_service']"/> <el-switch v-model="system['enabled_msg_service']"/>
</el-form-item> </el-form-item>

View File

@ -82,6 +82,9 @@
<el-form-item label="提问次数:" prop="calls"> <el-form-item label="提问次数:" prop="calls">
<el-input v-model.number="user.calls" autocomplete="off" placeholder="0"/> <el-input v-model.number="user.calls" autocomplete="off" placeholder="0"/>
</el-form-item> </el-form-item>
<el-form-item label="绘图次数:" prop="img_calls">
<el-input v-model.number="user['img_calls']" autocomplete="off" placeholder="0"/>
</el-form-item>
<el-form-item label="有效期:" prop="expired_time"> <el-form-item label="有效期:" prop="expired_time">
<el-date-picker <el-date-picker