feat: plugin function is ready

This commit is contained in:
RockYang
2023-07-15 21:52:30 +08:00
parent 3e41edd3b5
commit accf8eeb77
9 changed files with 200 additions and 46 deletions

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/service/function"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -30,12 +29,11 @@ const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
type ChatHandler struct {
BaseHandler
db *gorm.DB
funcZaoBao *function.FuncZaoBao
db *gorm.DB
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler {
handler := ChatHandler{db: db, funcZaoBao: zaoBao}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
handler.App = app
return &handler
}
@@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if !utils.IsEmptyValue(fun) {
functionCall = true
functionName = fun.Name
f := h.App.Functions[functionName]
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
continue
}
@@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if functionCall { // 调用函数完成任务
logger.Info(functionName)
logger.Info(arguments)
f := h.App.Functions[functionName]
// TODO 调用函数完成任务
data, err := h.funcZaoBao.Fetch()
data, err := f.Invoke(arguments)
if err != nil {
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
req.Messages = append(req.Messages, message)
totalTokens += getTotalTokens(req)
}
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
@@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Error("failed to save reply history message: ", res.Error)
}
// 统计用户 token 数量
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
req.Messages = append(req.Messages, message)
totalTokens += getTotalTokens(req)
}
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
historyUserMsg.Tokens+historyReplyMsg.Tokens))
totalTokens))
}
// 保存当前会话