opt: optimize the styles of chat page; caculate all tokens of context as chat history's token

This commit is contained in:
RockYang 2023-09-11 13:34:20 +08:00
parent e2c18c4e1e
commit 3cc8c3284a
14 changed files with 95 additions and 60 deletions

View File

@ -85,15 +85,27 @@ var InnerFunctions = []Function{
Type: "string", Type: "string",
Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文", Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
}, },
"ar": { "--ar": {
Type: "string", Type: "string",
Description: "图片长宽比,默认值 16:9", Description: "图片长宽比,默认值 16:9",
}, },
"niji": { "--niji": {
Type: "string", Type: "string",
Description: "动漫模型版本,默认值空", Description: "动漫模型版本,默认值空",
}, },
"v": { "--s": {
Type: "string",
Description: "风格stylize",
},
"--seed": {
Type: "string",
Description: "随机种子",
},
"--no": {
Type: "string",
Description: "负面提示词,指定不要什么元素或者风格,如果该参数中有中文的话,则需要翻译成英文",
},
"--v": {
Type: "string", Type: "string",
Description: "模型版本,默认值: 5.2", Description: "模型版本,默认值: 5.2",
}, },

View File

@ -220,17 +220,17 @@ func (h *ChatHandler) sendAzureMessage(
logger.Error("failed to save prompt history message: ", res.Error) logger.Error("failed to save prompt history message: ", res.Error)
} }
// for reply
// 计算本次对话消耗的总 token 数量 // 计算本次对话消耗的总 token 数量
var replyToken = 0 var totalTokens = 0
if functionCall { // 函数名 + 参数 token if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model) tokens, _ := utils.CalcTokens(functionName, req.Model)
replyToken += tokens totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyToken += tokens totalTokens += tokens
} else { } else {
replyToken, _ = utils.CalcTokens(message.Content, req.Model) totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
} }
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{ historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id, UserId: userVo.Id,
@ -239,7 +239,7 @@ func (h *ChatHandler) sendAzureMessage(
Type: types.ReplyMsg, Type: types.ReplyMsg,
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: message.Content,
Tokens: replyToken, Tokens: totalTokens,
UseContext: useContext, UseContext: useContext,
} }
historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.CreatedAt = replyCreatedAt
@ -249,13 +249,7 @@ func (h *ChatHandler) sendAzureMessage(
logger.Error("failed to save reply history message: ", res.Error) logger.Error("failed to save reply history message: ", res.Error)
} }
// 计算本次对话消耗的总 token 数量 // 更新用户信息
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
totalTokens = promptToken + replyToken
} else {
totalTokens = replyToken + getTotalTokens(req)
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id). h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
} }

View File

@ -291,6 +291,19 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
return return
} }
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
if data.Text == "" {
var item model.HistoryMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
resp.SUCCESS(c, item.Tokens)
return
}
tokens, err := utils.CalcTokens(data.Text, data.Model) tokens, err := utils.CalcTokens(data.Text, data.Model)
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())

View File

@ -146,9 +146,8 @@ func (h *ChatHandler) sendChatGLMMessage(
// for reply // for reply
// 计算本次对话消耗的总 token 数量 // 计算本次对话消耗的总 token 数量
var replyToken = 0 replyToken, _ := utils.CalcTokens(message.Content, req.Model)
replyToken, _ = utils.CalcTokens(message.Content, req.Model) totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{ historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id, UserId: userVo.Id,
ChatId: session.ChatId, ChatId: session.ChatId,
@ -156,7 +155,7 @@ func (h *ChatHandler) sendChatGLMMessage(
Type: types.ReplyMsg, Type: types.ReplyMsg,
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: message.Content,
Tokens: replyToken, Tokens: totalTokens,
UseContext: true, UseContext: true,
} }
historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.CreatedAt = replyCreatedAt
@ -165,10 +164,7 @@ func (h *ChatHandler) sendChatGLMMessage(
if res.Error != nil { if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error) logger.Error("failed to save reply history message: ", res.Error)
} }
// 更新用户信息
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
totalTokens = replyToken + getTotalTokens(req)
h.db.Model(&model.User{}).Where("id = ?", userVo.Id). h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
} }

View File

@ -95,6 +95,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
} }
data.Key = utils.Sha256(data.Prompt) data.Key = utils.Sha256(data.Prompt)
wsClient := h.App.MjTaskClients.Get(data.Key)
//logger.Info(data.Prompt, ",", key) //logger.Info(data.Prompt, ",", key)
if data.Status == Finished { if data.Status == Finished {
var task types.MjTask var task types.MjTask
@ -104,10 +105,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.SUCCESS(c) resp.SUCCESS(c)
return return
} }
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
utils.ReplyMessage(wsClient, content)
}
// download image // download image
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
if err != nil { if err != nil {
logger.Error("error with download image: ", err) logger.Error("error with download image: ", err)
if wsClient != nil && data.ReferenceId != "" {
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
utils.ReplyMessage(wsClient, content)
}
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return
} }
@ -144,8 +153,6 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
} }
} }
// 推送消息到客户端
wsClient := h.App.MjTaskClients.Get(data.Key)
if wsClient == nil { // 客户端断线,则丢弃 if wsClient == nil { // 客户端断线,则丢弃
logger.Errorf("Client is offline: %+v", data) logger.Errorf("Client is offline: %+v", data)
resp.SUCCESS(c, "Client is offline") resp.SUCCESS(c, "Client is offline")

View File

@ -220,17 +220,17 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Error("failed to save prompt history message: ", res.Error) logger.Error("failed to save prompt history message: ", res.Error)
} }
// for reply
// 计算本次对话消耗的总 token 数量 // 计算本次对话消耗的总 token 数量
var replyToken = 0 var totalTokens = 0
if functionCall { // 函数名 + 参数 token if functionCall { // prompt + 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model) tokens, _ := utils.CalcTokens(functionName, req.Model)
replyToken += tokens totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
replyToken += tokens totalTokens += tokens
} else { } else {
replyToken, _ = utils.CalcTokens(message.Content, req.Model) totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
} }
totalTokens += getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{ historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id, UserId: userVo.Id,
@ -239,7 +239,7 @@ func (h *ChatHandler) sendOpenAiMessage(
Type: types.ReplyMsg, Type: types.ReplyMsg,
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: message.Content,
Tokens: replyToken, Tokens: totalTokens,
UseContext: useContext, UseContext: useContext,
} }
historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.CreatedAt = replyCreatedAt
@ -249,13 +249,7 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Error("failed to save reply history message: ", res.Error) logger.Error("failed to save reply history message: ", res.Error)
} }
// 计算本次对话消耗的总 token 数量 // 更新用户信息
var totalTokens = 0
if functionCall { // prompt + 函数名 + 参数 token
totalTokens = promptToken + replyToken
} else {
totalTokens = replyToken + getTotalTokens(req)
}
h.db.Model(&model.User{}).Where("id = ?", userVo.Id). h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
} }

View File

@ -161,7 +161,7 @@ func main() {
group.GET("remove", h.Remove) group.GET("remove", h.Remove)
group.GET("history", h.History) group.GET("history", h.History)
group.GET("clear", h.Clear) group.GET("clear", h.Clear)
group.GET("tokens", h.Tokens) group.POST("tokens", h.Tokens)
group.GET("stop", h.StopGenerate) group.GET("stop", h.StopGenerate)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {

View File

@ -31,13 +31,25 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("MJ 绘画参数:%+v", params) logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
if !utils.IsEmptyValue(params["ar"]) { if !utils.IsEmptyValue(params["--ar"]) {
prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
delete(params, "--ar") delete(params, "--ar")
} }
if !utils.IsEmptyValue(params["niji"]) { if !utils.IsEmptyValue(params["--s"]) {
prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"]) prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
delete(params, "niji") delete(params, "--s")
}
if !utils.IsEmptyValue(params["--seed"]) {
prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
delete(params, "--seed")
}
if !utils.IsEmptyValue(params["--no"]) {
prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
delete(params, "--no")
}
if !utils.IsEmptyValue(params["--niji"]) {
prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
delete(params, "--niji")
} else { } else {
prompt = prompt + " --v 5.2" prompt = prompt + " --v 5.2"
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 169 KiB

After

Width:  |  Height:  |  Size: 5.8 KiB

View File

@ -14,7 +14,6 @@
font-size: 20px; font-size: 20px;
} }
#app .common-layout .el-aside .title-box .logo { #app .common-layout .el-aside .title-box .logo {
background-color: #fff;
border-radius: 8px; border-radius: 8px;
width: 35px; width: 35px;
height: 35px; height: 35px;

View File

@ -140,8 +140,8 @@ const send = (url, index) => {
margin-right 20px; margin-right 20px;
img { img {
width: 30px; width: 36px;
height: 30px; height: 36px;
border-radius: 10px; border-radius: 10px;
padding: 1px; padding: 1px;
} }

View File

@ -19,9 +19,8 @@
<script> <script>
import {defineComponent} from "vue" import {defineComponent} from "vue"
import {dateFormat} from "@/utils/libs";
import {Clock} from "@element-plus/icons-vue"; import {Clock} from "@element-plus/icons-vue";
import {httpGet} from "@/utils/http"; import {httpPost} from "@/utils/http";
export default defineComponent({ export default defineComponent({
name: 'ChatPrompt', name: 'ChatPrompt',
@ -56,7 +55,7 @@ export default defineComponent({
}, },
mounted() { mounted() {
if (!this.finalTokens) { if (!this.finalTokens) {
httpGet(`/api/chat/tokens?text=${this.content}&model=${this.model}`).then(res => { httpPost("/api/chat/tokens", {text: this.content, model: this.model}).then(res => {
this.finalTokens = res.data; this.finalTokens = res.data;
}) })
} }
@ -83,8 +82,8 @@ export default defineComponent({
margin-right 20px; margin-right 20px;
img { img {
width: 30px; width: 36px;
height: 30px; height: 36px;
border-radius: 10px; border-radius: 10px;
padding: 1px; padding: 1px;
} }

View File

@ -86,8 +86,8 @@ export default defineComponent({
margin-right 20px; margin-right 20px;
img { img {
width: 30px; width: 36px;
height: 30px; height: 36px;
border-radius: 10px; border-radius: 10px;
padding: 1px; padding: 1px;
} }

View File

@ -158,7 +158,7 @@
:icon="item.icon" :icon="item.icon"
:created-at="dateFormat(item['created_at'])" :created-at="dateFormat(item['created_at'])"
:tokens="item['tokens']" :tokens="item['tokens']"
:model="modelID" :model="getModelValue(modelID.value)"
:content="item.content"/> :content="item.content"/>
<chat-reply v-else-if="item.type==='reply'" <chat-reply v-else-if="item.type==='reply'"
:icon="item.icon" :icon="item.icon"
@ -601,7 +601,7 @@ const connect = function (chat_id, role_id) {
// token // token
const reply = chatData.value[chatData.value.length - 1] const reply = chatData.value[chatData.value.length - 1]
httpGet(`/api/chat/tokens?text=${reply.orgContent}&model=${modelID.value}`).then(res => { httpPost("/api/chat/tokens", {text: "", model: getModelValue(modelID.value)}).then(res => {
reply['created_at'] = new Date().getTime(); reply['created_at'] = new Date().getTime();
reply['tokens'] = res.data; reply['tokens'] = res.data;
// //
@ -813,7 +813,7 @@ const reGenerate = function () {
chatData.value.push({ chatData.value.push({
type: "prompt", type: "prompt",
id: randString(32), id: randString(32),
icon: 'images/avatar/user.png', icon: loginUser.value.avatar,
content: renderInputText(text) content: renderInputText(text)
}); });
socket.value.send(text); socket.value.send(text);
@ -859,6 +859,15 @@ const getChatById = (chatId) => {
} }
return null return null
} }
const getModelValue = (model_id) => {
for (let i = 0; i < models.value.length; i++) {
if (models.value[i].id === model_id) {
return models.value[i].value
}
}
return ""
}
</script> </script>
<style scoped lang="stylus"> <style scoped lang="stylus">