mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-12 04:03:42 +08:00
feat: plugin function is ready
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/function"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
@@ -30,12 +29,11 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
||||
|
||||
type ChatHandler struct {
|
||||
BaseHandler
|
||||
db *gorm.DB
|
||||
funcZaoBao *function.FuncZaoBao
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler {
|
||||
handler := ChatHandler{db: db, funcZaoBao: zaoBao}
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
handler := ChatHandler{db: db}
|
||||
handler.App = app
|
||||
return &handler
|
||||
}
|
||||
@@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionCall = true
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if functionCall { // 调用函数完成任务
|
||||
logger.Info(functionName)
|
||||
logger.Info(arguments)
|
||||
f := h.App.Functions[functionName]
|
||||
// TODO 调用函数完成任务
|
||||
data, err := h.funcZaoBao.Fetch()
|
||||
data, err := f.Invoke(arguments)
|
||||
if err != nil {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
@@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens += getTotalTokens(req)
|
||||
}
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
@@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 统计用户 token 数量
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens += getTotalTokens(req)
|
||||
}
|
||||
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
||||
totalTokens))
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
||||
Reference in New Issue
Block a user