mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
opt: optimize the styles of chat page; caculate all tokens of context as chat history's token
This commit is contained in:
parent
e2c18c4e1e
commit
3cc8c3284a
@ -85,15 +85,27 @@ var InnerFunctions = []Function{
|
|||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
|
Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
|
||||||
},
|
},
|
||||||
"ar": {
|
"--ar": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "图片长宽比,默认值 16:9",
|
Description: "图片长宽比,默认值 16:9",
|
||||||
},
|
},
|
||||||
"niji": {
|
"--niji": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "动漫模型版本,默认值空",
|
Description: "动漫模型版本,默认值空",
|
||||||
},
|
},
|
||||||
"v": {
|
"--s": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "风格,stylize",
|
||||||
|
},
|
||||||
|
"--seed": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "随机种子",
|
||||||
|
},
|
||||||
|
"--no": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "负面提示词,指定不要什么元素或者风格,如果该参数中有中文的话,则需要翻译成英文",
|
||||||
|
},
|
||||||
|
"--v": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "模型版本,默认值: 5.2",
|
Description: "模型版本,默认值: 5.2",
|
||||||
},
|
},
|
||||||
|
@ -220,17 +220,17 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
var replyToken = 0
|
var totalTokens = 0
|
||||||
if functionCall { // 函数名 + 参数 token
|
if functionCall { // prompt + 函数名 + 参数 token
|
||||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||||
replyToken += tokens
|
totalTokens += tokens
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||||
replyToken += tokens
|
totalTokens += tokens
|
||||||
} else {
|
} else {
|
||||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||||
}
|
}
|
||||||
|
totalTokens += getTotalTokens(req)
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
historyReplyMsg := model.HistoryMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
@ -239,7 +239,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
Type: types.ReplyMsg,
|
Type: types.ReplyMsg,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
Tokens: replyToken,
|
Tokens: totalTokens,
|
||||||
UseContext: useContext,
|
UseContext: useContext,
|
||||||
}
|
}
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
@ -249,13 +249,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
logger.Error("failed to save reply history message: ", res.Error)
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 更新用户信息
|
||||||
var totalTokens = 0
|
|
||||||
if functionCall { // prompt + 函数名 + 参数 token
|
|
||||||
totalTokens = promptToken + replyToken
|
|
||||||
} else {
|
|
||||||
totalTokens = replyToken + getTotalTokens(req)
|
|
||||||
}
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||||
}
|
}
|
||||||
|
@ -291,6 +291,19 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
|
||||||
|
if data.Text == "" {
|
||||||
|
var item model.HistoryMessage
|
||||||
|
userId, _ := c.Get(types.LoginUserID)
|
||||||
|
res := h.db.Where("user_id = ?", userId).Last(&item)
|
||||||
|
if res.Error != nil {
|
||||||
|
resp.ERROR(c, res.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, item.Tokens)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tokens, err := utils.CalcTokens(data.Text, data.Model)
|
tokens, err := utils.CalcTokens(data.Text, data.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
|
@ -146,9 +146,8 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
|
|
||||||
// for reply
|
// for reply
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
var replyToken = 0
|
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
totalTokens := replyToken + getTotalTokens(req)
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
historyReplyMsg := model.HistoryMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
@ -156,7 +155,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
Type: types.ReplyMsg,
|
Type: types.ReplyMsg,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
Tokens: replyToken,
|
Tokens: totalTokens,
|
||||||
UseContext: true,
|
UseContext: true,
|
||||||
}
|
}
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
@ -165,10 +164,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
|||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
logger.Error("failed to save reply history message: ", res.Error)
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
// 更新用户信息
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
var totalTokens = 0
|
|
||||||
totalTokens = replyToken + getTotalTokens(req)
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||||
}
|
}
|
||||||
|
@ -95,6 +95,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data.Key = utils.Sha256(data.Prompt)
|
data.Key = utils.Sha256(data.Prompt)
|
||||||
|
wsClient := h.App.MjTaskClients.Get(data.Key)
|
||||||
//logger.Info(data.Prompt, ",", key)
|
//logger.Info(data.Prompt, ",", key)
|
||||||
if data.Status == Finished {
|
if data.Status == Finished {
|
||||||
var task types.MjTask
|
var task types.MjTask
|
||||||
@ -104,10 +105,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if wsClient != nil && data.ReferenceId != "" {
|
||||||
|
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
|
}
|
||||||
// download image
|
// download image
|
||||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with download image: ", err)
|
logger.Error("error with download image: ", err)
|
||||||
|
if wsClient != nil && data.ReferenceId != "" {
|
||||||
|
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
|
}
|
||||||
resp.ERROR(c, err.Error())
|
resp.ERROR(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -144,8 +153,6 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 推送消息到客户端
|
|
||||||
wsClient := h.App.MjTaskClients.Get(data.Key)
|
|
||||||
if wsClient == nil { // 客户端断线,则丢弃
|
if wsClient == nil { // 客户端断线,则丢弃
|
||||||
logger.Errorf("Client is offline: %+v", data)
|
logger.Errorf("Client is offline: %+v", data)
|
||||||
resp.SUCCESS(c, "Client is offline")
|
resp.SUCCESS(c, "Client is offline")
|
||||||
|
@ -220,17 +220,17 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
logger.Error("failed to save prompt history message: ", res.Error)
|
logger.Error("failed to save prompt history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// for reply
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
var replyToken = 0
|
var totalTokens = 0
|
||||||
if functionCall { // 函数名 + 参数 token
|
if functionCall { // prompt + 函数名 + 参数 token
|
||||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||||
replyToken += tokens
|
totalTokens += tokens
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||||
replyToken += tokens
|
totalTokens += tokens
|
||||||
} else {
|
} else {
|
||||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||||
}
|
}
|
||||||
|
totalTokens += getTotalTokens(req)
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
historyReplyMsg := model.HistoryMessage{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
@ -239,7 +239,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
Type: types.ReplyMsg,
|
Type: types.ReplyMsg,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
Tokens: replyToken,
|
Tokens: totalTokens,
|
||||||
UseContext: useContext,
|
UseContext: useContext,
|
||||||
}
|
}
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
@ -249,13 +249,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
logger.Error("failed to save reply history message: ", res.Error)
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 更新用户信息
|
||||||
var totalTokens = 0
|
|
||||||
if functionCall { // prompt + 函数名 + 参数 token
|
|
||||||
totalTokens = promptToken + replyToken
|
|
||||||
} else {
|
|
||||||
totalTokens = replyToken + getTotalTokens(req)
|
|
||||||
}
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,7 @@ func main() {
|
|||||||
group.GET("remove", h.Remove)
|
group.GET("remove", h.Remove)
|
||||||
group.GET("history", h.History)
|
group.GET("history", h.History)
|
||||||
group.GET("clear", h.Clear)
|
group.GET("clear", h.Clear)
|
||||||
group.GET("tokens", h.Tokens)
|
group.POST("tokens", h.Tokens)
|
||||||
group.GET("stop", h.StopGenerate)
|
group.GET("stop", h.StopGenerate)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
|
||||||
|
@ -31,13 +31,25 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
|||||||
|
|
||||||
logger.Infof("MJ 绘画参数:%+v", params)
|
logger.Infof("MJ 绘画参数:%+v", params)
|
||||||
prompt := utils.InterfaceToString(params["prompt"])
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
if !utils.IsEmptyValue(params["ar"]) {
|
if !utils.IsEmptyValue(params["--ar"]) {
|
||||||
prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"])
|
prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
|
||||||
delete(params, "--ar")
|
delete(params, "--ar")
|
||||||
}
|
}
|
||||||
if !utils.IsEmptyValue(params["niji"]) {
|
if !utils.IsEmptyValue(params["--s"]) {
|
||||||
prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"])
|
prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
|
||||||
delete(params, "niji")
|
delete(params, "--s")
|
||||||
|
}
|
||||||
|
if !utils.IsEmptyValue(params["--seed"]) {
|
||||||
|
prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
|
||||||
|
delete(params, "--seed")
|
||||||
|
}
|
||||||
|
if !utils.IsEmptyValue(params["--no"]) {
|
||||||
|
prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
|
||||||
|
delete(params, "--no")
|
||||||
|
}
|
||||||
|
if !utils.IsEmptyValue(params["--niji"]) {
|
||||||
|
prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
|
||||||
|
delete(params, "--niji")
|
||||||
} else {
|
} else {
|
||||||
prompt = prompt + " --v 5.2"
|
prompt = prompt + " --v 5.2"
|
||||||
}
|
}
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 169 KiB After Width: | Height: | Size: 5.8 KiB |
@ -14,7 +14,6 @@
|
|||||||
font-size: 20px;
|
font-size: 20px;
|
||||||
}
|
}
|
||||||
#app .common-layout .el-aside .title-box .logo {
|
#app .common-layout .el-aside .title-box .logo {
|
||||||
background-color: #fff;
|
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
width: 35px;
|
width: 35px;
|
||||||
height: 35px;
|
height: 35px;
|
||||||
|
@ -140,8 +140,8 @@ const send = (url, index) => {
|
|||||||
margin-right 20px;
|
margin-right 20px;
|
||||||
|
|
||||||
img {
|
img {
|
||||||
width: 30px;
|
width: 36px;
|
||||||
height: 30px;
|
height: 36px;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
padding: 1px;
|
padding: 1px;
|
||||||
}
|
}
|
||||||
|
@ -19,9 +19,8 @@
|
|||||||
|
|
||||||
<script>
|
<script>
|
||||||
import {defineComponent} from "vue"
|
import {defineComponent} from "vue"
|
||||||
import {dateFormat} from "@/utils/libs";
|
|
||||||
import {Clock} from "@element-plus/icons-vue";
|
import {Clock} from "@element-plus/icons-vue";
|
||||||
import {httpGet} from "@/utils/http";
|
import {httpPost} from "@/utils/http";
|
||||||
|
|
||||||
export default defineComponent({
|
export default defineComponent({
|
||||||
name: 'ChatPrompt',
|
name: 'ChatPrompt',
|
||||||
@ -56,7 +55,7 @@ export default defineComponent({
|
|||||||
},
|
},
|
||||||
mounted() {
|
mounted() {
|
||||||
if (!this.finalTokens) {
|
if (!this.finalTokens) {
|
||||||
httpGet(`/api/chat/tokens?text=${this.content}&model=${this.model}`).then(res => {
|
httpPost("/api/chat/tokens", {text: this.content, model: this.model}).then(res => {
|
||||||
this.finalTokens = res.data;
|
this.finalTokens = res.data;
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -83,8 +82,8 @@ export default defineComponent({
|
|||||||
margin-right 20px;
|
margin-right 20px;
|
||||||
|
|
||||||
img {
|
img {
|
||||||
width: 30px;
|
width: 36px;
|
||||||
height: 30px;
|
height: 36px;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
padding: 1px;
|
padding: 1px;
|
||||||
}
|
}
|
||||||
|
@ -86,8 +86,8 @@ export default defineComponent({
|
|||||||
margin-right 20px;
|
margin-right 20px;
|
||||||
|
|
||||||
img {
|
img {
|
||||||
width: 30px;
|
width: 36px;
|
||||||
height: 30px;
|
height: 36px;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
padding: 1px;
|
padding: 1px;
|
||||||
}
|
}
|
||||||
|
@ -158,7 +158,7 @@
|
|||||||
:icon="item.icon"
|
:icon="item.icon"
|
||||||
:created-at="dateFormat(item['created_at'])"
|
:created-at="dateFormat(item['created_at'])"
|
||||||
:tokens="item['tokens']"
|
:tokens="item['tokens']"
|
||||||
:model="modelID"
|
:model="getModelValue(modelID.value)"
|
||||||
:content="item.content"/>
|
:content="item.content"/>
|
||||||
<chat-reply v-else-if="item.type==='reply'"
|
<chat-reply v-else-if="item.type==='reply'"
|
||||||
:icon="item.icon"
|
:icon="item.icon"
|
||||||
@ -601,7 +601,7 @@ const connect = function (chat_id, role_id) {
|
|||||||
|
|
||||||
// 获取 token
|
// 获取 token
|
||||||
const reply = chatData.value[chatData.value.length - 1]
|
const reply = chatData.value[chatData.value.length - 1]
|
||||||
httpGet(`/api/chat/tokens?text=${reply.orgContent}&model=${modelID.value}`).then(res => {
|
httpPost("/api/chat/tokens", {text: "", model: getModelValue(modelID.value)}).then(res => {
|
||||||
reply['created_at'] = new Date().getTime();
|
reply['created_at'] = new Date().getTime();
|
||||||
reply['tokens'] = res.data;
|
reply['tokens'] = res.data;
|
||||||
// 将聊天框的滚动条滑动到最底部
|
// 将聊天框的滚动条滑动到最底部
|
||||||
@ -813,7 +813,7 @@ const reGenerate = function () {
|
|||||||
chatData.value.push({
|
chatData.value.push({
|
||||||
type: "prompt",
|
type: "prompt",
|
||||||
id: randString(32),
|
id: randString(32),
|
||||||
icon: 'images/avatar/user.png',
|
icon: loginUser.value.avatar,
|
||||||
content: renderInputText(text)
|
content: renderInputText(text)
|
||||||
});
|
});
|
||||||
socket.value.send(text);
|
socket.value.send(text);
|
||||||
@ -859,6 +859,15 @@ const getChatById = (chatId) => {
|
|||||||
}
|
}
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getModelValue = (model_id) => {
|
||||||
|
for (let i = 0; i < models.value.length; i++) {
|
||||||
|
if (models.value[i].id === model_id) {
|
||||||
|
return models.value[i].value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped lang="stylus">
|
<style scoped lang="stylus">
|
||||||
|
Loading…
Reference in New Issue
Block a user