diff --git a/api/core/types/function.go b/api/core/types/function.go index c422118d..dd0a6e4e 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -85,15 +85,27 @@ var InnerFunctions = []Function{ Type: "string", Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文", }, - "ar": { + "--ar": { Type: "string", Description: "图片长宽比,默认值 16:9", }, - "niji": { + "--niji": { Type: "string", Description: "动漫模型版本,默认值空", }, - "v": { + "--s": { + Type: "string", + Description: "风格,stylize", + }, + "--seed": { + Type: "string", + Description: "随机种子", + }, + "--no": { + Type: "string", + Description: "负面提示词,指定不要什么元素或者风格,如果该参数中有中文的话,则需要翻译成英文", + }, + "--v": { Type: "string", Description: "模型版本,默认值: 5.2", }, diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go index 39fc8b3c..9238036c 100644 --- a/api/handler/azure_handler.go +++ b/api/handler/azure_handler.go @@ -220,17 +220,17 @@ func (h *ChatHandler) sendAzureMessage( logger.Error("failed to save prompt history message: ", res.Error) } - // for reply // 计算本次对话消耗的总 token 数量 - var replyToken = 0 - if functionCall { // 函数名 + 参数 token + var totalTokens = 0 + if functionCall { // prompt + 函数名 + 参数 token tokens, _ := utils.CalcTokens(functionName, req.Model) - replyToken += tokens + totalTokens += tokens tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) - replyToken += tokens + totalTokens += tokens } else { - replyToken, _ = utils.CalcTokens(message.Content, req.Model) + totalTokens, _ = utils.CalcTokens(message.Content, req.Model) } + totalTokens += getTotalTokens(req) historyReplyMsg := model.HistoryMessage{ UserId: userVo.Id, @@ -239,7 +239,7 @@ func (h *ChatHandler) sendAzureMessage( Type: types.ReplyMsg, Icon: role.Icon, Content: message.Content, - Tokens: replyToken, + Tokens: totalTokens, UseContext: useContext, } historyReplyMsg.CreatedAt = replyCreatedAt @@ -249,13 +249,7 @@ func (h *ChatHandler) sendAzureMessage( logger.Error("failed to save reply history message: ", res.Error) } - // 计算本次对话消耗的总 token 数量 - var totalTokens = 0 - if functionCall { // prompt + 函数名 + 参数 token - totalTokens = promptToken + replyToken - } else { - totalTokens = replyToken + getTotalTokens(req) - } + // 更新用户信息 h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 83378e03..0ea51f2a 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -291,6 +291,19 @@ func (h *ChatHandler) Tokens(c *gin.Context) { return } + // 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文) + if data.Text == "" { + var item model.HistoryMessage + userId, _ := c.Get(types.LoginUserID) + res := h.db.Where("user_id = ?", userId).Last(&item) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + resp.SUCCESS(c, item.Tokens) + return + } + tokens, err := utils.CalcTokens(data.Text, data.Model) if err != nil { resp.ERROR(c, err.Error()) diff --git a/api/handler/chatglm_handler.go b/api/handler/chatglm_handler.go index c97c2ce5..90844a8a 100644 --- a/api/handler/chatglm_handler.go +++ b/api/handler/chatglm_handler.go @@ -146,9 +146,8 @@ func (h *ChatHandler) sendChatGLMMessage( // for reply // 计算本次对话消耗的总 token 数量 - var replyToken = 0 - replyToken, _ = utils.CalcTokens(message.Content, req.Model) - + replyToken, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyToken + getTotalTokens(req) historyReplyMsg := model.HistoryMessage{ UserId: userVo.Id, ChatId: session.ChatId, @@ -156,7 +155,7 @@ func (h *ChatHandler) sendChatGLMMessage( Type: types.ReplyMsg, Icon: role.Icon, Content: message.Content, - Tokens: replyToken, + Tokens: totalTokens, UseContext: true, } historyReplyMsg.CreatedAt = replyCreatedAt @@ -165,10 +164,7 @@ func (h *ChatHandler) sendChatGLMMessage( if res.Error != nil { logger.Error("failed to save reply history message: ", res.Error) } - - // 计算本次对话消耗的总 token 数量 - var totalTokens = 0 - totalTokens = replyToken + getTotalTokens(req) + // 更新用户信息 h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 5cef45a7..7da7ba21 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -95,6 +95,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } data.Key = utils.Sha256(data.Prompt) + wsClient := h.App.MjTaskClients.Get(data.Key) //logger.Info(data.Prompt, ",", key) if data.Status == Finished { var task types.MjTask @@ -104,10 +105,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { resp.SUCCESS(c) return } + if wsClient != nil && data.ReferenceId != "" { + content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) + utils.ReplyMessage(wsClient, content) + } // download image imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) if err != nil { logger.Error("error with download image: ", err) + if wsClient != nil && data.ReferenceId != "" { + content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) + utils.ReplyMessage(wsClient, content) + } resp.ERROR(c, err.Error()) return } @@ -144,8 +153,6 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } } - // 推送消息到客户端 - wsClient := h.App.MjTaskClients.Get(data.Key) if wsClient == nil { // 客户端断线,则丢弃 logger.Errorf("Client is offline: %+v", data) resp.SUCCESS(c, "Client is offline") diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go index bcafc6ce..73798b8b 100644 --- a/api/handler/openai_handler.go +++ b/api/handler/openai_handler.go @@ -220,17 +220,17 @@ func (h *ChatHandler) sendOpenAiMessage( logger.Error("failed to save prompt history message: ", res.Error) } - // for reply // 计算本次对话消耗的总 token 数量 - var replyToken = 0 - if functionCall { // 函数名 + 参数 token + var totalTokens = 0 + if functionCall { // prompt + 函数名 + 参数 token tokens, _ := utils.CalcTokens(functionName, req.Model) - replyToken += tokens + totalTokens += tokens tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) - replyToken += tokens + totalTokens += tokens } else { - replyToken, _ = utils.CalcTokens(message.Content, req.Model) + totalTokens, _ = utils.CalcTokens(message.Content, req.Model) } + totalTokens += getTotalTokens(req) historyReplyMsg := model.HistoryMessage{ UserId: userVo.Id, @@ -239,7 +239,7 @@ func (h *ChatHandler) sendOpenAiMessage( Type: types.ReplyMsg, Icon: role.Icon, Content: message.Content, - Tokens: replyToken, + Tokens: totalTokens, UseContext: useContext, } historyReplyMsg.CreatedAt = replyCreatedAt @@ -249,13 +249,7 @@ func (h *ChatHandler) sendOpenAiMessage( logger.Error("failed to save reply history message: ", res.Error) } - // 计算本次对话消耗的总 token 数量 - var totalTokens = 0 - if functionCall { // prompt + 函数名 + 参数 token - totalTokens = promptToken + replyToken - } else { - totalTokens = replyToken + getTotalTokens(req) - } + // 更新用户信息 h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } diff --git a/api/main.go b/api/main.go index 2faae0c7..2860392f 100644 --- a/api/main.go +++ b/api/main.go @@ -161,7 +161,7 @@ func main() { group.GET("remove", h.Remove) group.GET("history", h.History) group.GET("clear", h.Clear) - group.GET("tokens", h.Tokens) + group.POST("tokens", h.Tokens) group.GET("stop", h.StopGenerate) }), fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index cf5455e4..237ec08c 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -31,13 +31,25 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) - if !utils.IsEmptyValue(params["ar"]) { - prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) + if !utils.IsEmptyValue(params["--ar"]) { + prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"]) delete(params, "--ar") } - if !utils.IsEmptyValue(params["niji"]) { - prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"]) - delete(params, "niji") + if !utils.IsEmptyValue(params["--s"]) { + prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"]) + delete(params, "--s") + } + if !utils.IsEmptyValue(params["--seed"]) { + prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"]) + delete(params, "--seed") + } + if !utils.IsEmptyValue(params["--no"]) { + prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"]) + delete(params, "--no") + } + if !utils.IsEmptyValue(params["--niji"]) { + prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"]) + delete(params, "--niji") } else { prompt = prompt + " --v 5.2" } diff --git a/web/public/images/avatar/gpt.png b/web/public/images/avatar/gpt.png index 873ce1d6..67b6a102 100644 Binary files a/web/public/images/avatar/gpt.png and b/web/public/images/avatar/gpt.png differ diff --git a/web/src/assets/css/chat-plus.css b/web/src/assets/css/chat-plus.css index c6e52bf4..af3e7cc9 100644 --- a/web/src/assets/css/chat-plus.css +++ b/web/src/assets/css/chat-plus.css @@ -14,7 +14,6 @@ font-size: 20px; } #app .common-layout .el-aside .title-box .logo { - background-color: #fff; border-radius: 8px; width: 35px; height: 35px; diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue index 6db8523e..1de544ea 100644 --- a/web/src/components/ChatMidJourney.vue +++ b/web/src/components/ChatMidJourney.vue @@ -140,8 +140,8 @@ const send = (url, index) => { margin-right 20px; img { - width: 30px; - height: 30px; + width: 36px; + height: 36px; border-radius: 10px; padding: 1px; } diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index 7bf9a105..4bc07991 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -19,9 +19,8 @@