diff --git a/api/core/types/function.go b/api/core/types/function.go
index c422118d..dd0a6e4e 100644
--- a/api/core/types/function.go
+++ b/api/core/types/function.go
@@ -85,15 +85,27 @@ var InnerFunctions = []Function{
Type: "string",
Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
},
- "ar": {
+ "--ar": {
Type: "string",
Description: "图片长宽比,默认值 16:9",
},
- "niji": {
+ "--niji": {
Type: "string",
Description: "动漫模型版本,默认值空",
},
- "v": {
+ "--s": {
+ Type: "string",
+ Description: "风格,stylize",
+ },
+ "--seed": {
+ Type: "string",
+ Description: "随机种子",
+ },
+ "--no": {
+ Type: "string",
+ Description: "负面提示词,指定不要什么元素或者风格,如果该参数中有中文的话,则需要翻译成英文",
+ },
+ "--v": {
Type: "string",
Description: "模型版本,默认值: 5.2",
},
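
For orientation: with the definitions above, the model's function call now carries the Midjourney flags directly as argument keys. A made-up example of such an arguments map and the prompt it would end up producing (the values are illustrative only, not taken from the project):

args := map[string]interface{}{
	"prompt": "a cyberpunk city at night, neon lights",
	"--ar":   "16:9",
	"--s":    "250",
	"--seed": "12345",
	"--no":   "people, cars",
}
// FuncMidJourney.Invoke (api/service/function/mid_journey.go, later in this patch)
// appends each non-empty flag to the prompt and adds --v 5.2 when --niji is empty:
// "a cyberpunk city at night, neon lights --ar 16:9 --s 250 --seed 12345 --no people, cars --v 5.2"
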
diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go
index 39fc8b3c..9238036c 100644
--- a/api/handler/azure_handler.go
+++ b/api/handler/azure_handler.go
@@ -220,17 +220,17 @@ func (h *ChatHandler) sendAzureMessage(
logger.Error("failed to save prompt history message: ", res.Error)
}
- // for reply
// calculate the total number of tokens consumed by this conversation
- var replyToken = 0
- if functionCall { // function name + argument tokens
+ var totalTokens = 0
+ if functionCall { // prompt + function name + argument tokens
tokens, _ := utils.CalcTokens(functionName, req.Model)
- replyToken += tokens
+ totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
- replyToken += tokens
+ totalTokens += tokens
} else {
- replyToken, _ = utils.CalcTokens(message.Content, req.Model)
+ totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
+ totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
@@ -239,7 +239,7 @@ func (h *ChatHandler) sendAzureMessage(
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
- Tokens: replyToken,
+ Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
@@ -249,13 +249,7 @@ func (h *ChatHandler) sendAzureMessage(
logger.Error("failed to save reply history message: ", res.Error)
}
- // calculate the total number of tokens consumed by this conversation
- var totalTokens = 0
- if functionCall { // prompt + function name + argument tokens
- totalTokens = promptToken + replyToken
- } else {
- totalTokens = replyToken + getTotalTokens(req)
- }
+ // update user info
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
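
The reply record now stores the full per-conversation cost: the content tokens (or, for a function call, the function name plus serialized argument tokens) plus the context tokens from getTotalTokens(req). Since the same block also appears in openai_handler.go and chatglm_handler.go below, it could be factored into a shared helper; a minimal sketch under that assumption (calcReplyTokens is a hypothetical name, and errors are ignored as in the handlers):

// calcReplyTokens mirrors the accounting above: function calls are billed by
// function name + serialized arguments, plain replies by their content, and the
// request-side (context) tokens are added in both cases.
func calcReplyTokens(functionCall bool, functionName string, arguments interface{},
	content, model string, contextTokens int) int {
	totalTokens := 0
	if functionCall {
		tokens, _ := utils.CalcTokens(functionName, model)
		totalTokens += tokens
		tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), model)
		totalTokens += tokens
	} else {
		totalTokens, _ = utils.CalcTokens(content, model)
	}
	return totalTokens + contextTokens // contextTokens = getTotalTokens(req)
}
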
diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index 83378e03..0ea51f2a 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -291,6 +291,19 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
return
}
+ // if no text field is passed in, return the total token usage of the latest reply (context included)
+ if data.Text == "" {
+ var item model.HistoryMessage
+ userId, _ := c.Get(types.LoginUserID)
+ res := h.db.Where("user_id = ?", userId).Last(&item)
+ if res.Error != nil {
+ resp.ERROR(c, res.Error.Error())
+ return
+ }
+ resp.SUCCESS(c, item.Tokens)
+ return
+ }
+
tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil {
resp.ERROR(c, err.Error())
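
Together with the GET-to-POST route change in api/main.go below, the endpoint now has two behaviors: with a non-empty text it counts that string's tokens for the given model, and with an empty text it returns the Tokens value stored on the user's latest HistoryMessage (the context-inclusive cost of the last reply). A rough client-side sketch; the URL prefix, port, and JSON field names are assumptions, and the real call also needs the logged-in user's auth headers (omitted here):

import (
	"bytes"
	"encoding/json"
	"net/http"
)

// countTokens posts to the tokens endpoint; an empty text asks the server for the
// token usage recorded on the user's latest reply instead of counting a string.
func countTokens(text, model string) (*http.Response, error) {
	body, _ := json.Marshal(map[string]string{"text": text, "model": model})
	return http.Post("http://localhost:5678/api/chat/tokens", "application/json", bytes.NewReader(body))
}

// countTokens("Hello, world!", "gpt-3.5-turbo") // token count of the string
// countTokens("", "gpt-3.5-turbo")              // Tokens of the user's last HistoryMessage
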
diff --git a/api/handler/chatglm_handler.go b/api/handler/chatglm_handler.go
index c97c2ce5..90844a8a 100644
--- a/api/handler/chatglm_handler.go
+++ b/api/handler/chatglm_handler.go
@@ -146,9 +146,8 @@ func (h *ChatHandler) sendChatGLMMessage(
// for reply
// calculate the total number of tokens consumed by this conversation
- var replyToken = 0
- replyToken, _ = utils.CalcTokens(message.Content, req.Model)
-
+ replyToken, _ := utils.CalcTokens(message.Content, req.Model)
+ totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
@@ -156,7 +155,7 @@ func (h *ChatHandler) sendChatGLMMessage(
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
- Tokens: replyToken,
+ Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
@@ -165,10 +164,7 @@ func (h *ChatHandler) sendChatGLMMessage(
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
-
- // calculate the total number of tokens consumed by this conversation
- var totalTokens = 0
- totalTokens = replyToken + getTotalTokens(req)
+ // update user info
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go
index 5cef45a7..7da7ba21 100644
--- a/api/handler/mj_handler.go
+++ b/api/handler/mj_handler.go
@@ -95,6 +95,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
}
data.Key = utils.Sha256(data.Prompt)
+ wsClient := h.App.MjTaskClients.Get(data.Key)
//logger.Info(data.Prompt, ",", key)
if data.Status == Finished {
var task types.MjTask
@@ -104,10 +105,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.SUCCESS(c)
return
}
+ if wsClient != nil && data.ReferenceId != "" {
+ content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
+ utils.ReplyMessage(wsClient, content)
+ }
// download image
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
if err != nil {
logger.Error("error with download image: ", err)
+ if wsClient != nil && data.ReferenceId != "" {
+ content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
+ utils.ReplyMessage(wsClient, content)
+ }
resp.ERROR(c, err.Error())
return
}
@@ -144,8 +153,6 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
}
}
- // push the message to the client
- wsClient := h.App.MjTaskClients.Get(data.Key)
if wsClient == nil { // client disconnected, discard the message
logger.Errorf("Client is offline: %+v", data)
resp.SUCCESS(c, "Client is offline")
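
Looking up wsClient right after computing data.Key lets the handler push interim status ("downloading", "download failed") before the final result arrives. The nil-client / empty-ReferenceId guard is repeated at each step, so it could be wrapped in a tiny helper; a sketch, assuming MjTaskClients.Get returns a *types.WsClient (the type and the helper name notifyTask are assumptions, not from the patch):

// notifyTask pushes a status message to the task's websocket client, if one is
// attached and the callback carries a ReferenceId; otherwise it does nothing.
func notifyTask(wsClient *types.WsClient, referenceId, format string, args ...interface{}) {
	if wsClient == nil || referenceId == "" {
		return
	}
	utils.ReplyMessage(wsClient, fmt.Sprintf(format, args...))
}

// e.g. notifyTask(wsClient, data.ReferenceId,
//	"**%s** task finished successfully, downloading the image from the MidJourney server, please wait...", data.Prompt)
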
diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go
index bcafc6ce..73798b8b 100644
--- a/api/handler/openai_handler.go
+++ b/api/handler/openai_handler.go
@@ -220,17 +220,17 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Error("failed to save prompt history message: ", res.Error)
}
- // for reply
// calculate the total number of tokens consumed by this conversation
- var replyToken = 0
- if functionCall { // function name + argument tokens
+ var totalTokens = 0
+ if functionCall { // prompt + function name + argument tokens
tokens, _ := utils.CalcTokens(functionName, req.Model)
- replyToken += tokens
+ totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
- replyToken += tokens
+ totalTokens += tokens
} else {
- replyToken, _ = utils.CalcTokens(message.Content, req.Model)
+ totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
}
+ totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
@@ -239,7 +239,7 @@ func (h *ChatHandler) sendOpenAiMessage(
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
- Tokens: replyToken,
+ Tokens: totalTokens,
UseContext: useContext,
}
historyReplyMsg.CreatedAt = replyCreatedAt
@@ -249,13 +249,7 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Error("failed to save reply history message: ", res.Error)
}
- // calculate the total number of tokens consumed by this conversation
- var totalTokens = 0
- if functionCall { // prompt + function name + argument tokens
- totalTokens = promptToken + replyToken
- } else {
- totalTokens = replyToken + getTotalTokens(req)
- }
+ // update user info
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
diff --git a/api/main.go b/api/main.go
index 2faae0c7..2860392f 100644
--- a/api/main.go
+++ b/api/main.go
@@ -161,7 +161,7 @@ func main() {
group.GET("remove", h.Remove)
group.GET("history", h.History)
group.GET("clear", h.Clear)
- group.GET("tokens", h.Tokens)
+ group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate)
}),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go
index cf5455e4..237ec08c 100644
--- a/api/service/function/mid_journey.go
+++ b/api/service/function/mid_journey.go
@@ -31,13 +31,25 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
- if !utils.IsEmptyValue(params["ar"]) {
- prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"])
+ if !utils.IsEmptyValue(params["--ar"]) {
+ prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
delete(params, "--ar")
}
- if !utils.IsEmptyValue(params["niji"]) {
- prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"])
- delete(params, "niji")
+ if !utils.IsEmptyValue(params["--s"]) {
+ prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
+ delete(params, "--s")
+ }
+ if !utils.IsEmptyValue(params["--seed"]) {
+ prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
+ delete(params, "--seed")
+ }
+ if !utils.IsEmptyValue(params["--no"]) {
+ prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
+ delete(params, "--no")
+ }
+ if !utils.IsEmptyValue(params["--niji"]) {
+ prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
+ delete(params, "--niji")
} else {
prompt = prompt + " --v 5.2"
}
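
The per-flag blocks above are structurally identical, so they could also be written as a single loop over the flag names; a behavior-preserving sketch of that alternative, using only calls that already appear in the patch (the --niji / --v fallback is kept as-is):

// Append each non-empty Midjourney flag to the prompt, then fall back to --v 5.2
// when no anime (niji) model was requested. Equivalent to the per-flag blocks above.
for _, flag := range []string{"--ar", "--s", "--seed", "--no"} {
	if !utils.IsEmptyValue(params[flag]) {
		prompt = fmt.Sprintf("%s %s %s", prompt, flag, params[flag])
		delete(params, flag)
	}
}
if !utils.IsEmptyValue(params["--niji"]) {
	prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
	delete(params, "--niji")
} else {
	prompt = prompt + " --v 5.2"
}
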
diff --git a/web/public/images/avatar/gpt.png b/web/public/images/avatar/gpt.png
index 873ce1d6..67b6a102 100644
Binary files a/web/public/images/avatar/gpt.png and b/web/public/images/avatar/gpt.png differ
diff --git a/web/src/assets/css/chat-plus.css b/web/src/assets/css/chat-plus.css
index c6e52bf4..af3e7cc9 100644
--- a/web/src/assets/css/chat-plus.css
+++ b/web/src/assets/css/chat-plus.css
@@ -14,7 +14,6 @@
font-size: 20px;
}
#app .common-layout .el-aside .title-box .logo {
- background-color: #fff;
border-radius: 8px;
width: 35px;
height: 35px;
diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue
index 6db8523e..1de544ea 100644
--- a/web/src/components/ChatMidJourney.vue
+++ b/web/src/components/ChatMidJourney.vue
@@ -140,8 +140,8 @@ const send = (url, index) => {
margin-right 20px;
img {
- width: 30px;
- height: 30px;
+ width: 36px;
+ height: 36px;
border-radius: 10px;
padding: 1px;
}
diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue
index 7bf9a105..4bc07991 100644
--- a/web/src/components/ChatPrompt.vue
+++ b/web/src/components/ChatPrompt.vue
@@ -19,9 +19,8 @@