mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 10:43:44 +08:00
feat: 头条,微博热搜等函数 API 实现
This commit is contained in:
@@ -315,17 +315,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// TODO 调用函数完成任务
|
||||
data, err := f.Invoke(arguments)
|
||||
if err != nil {
|
||||
msg := "调用函数出错:" + err.Error()
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: "调用函数出错:" + err.Error(),
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: data,
|
||||
})
|
||||
contents = append(contents, data)
|
||||
}
|
||||
contents = append(contents, data)
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
@@ -359,7 +361,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
}
|
||||
|
||||
// for prompt
|
||||
token, err := utils.CalcTokens(prompt, req.Model)
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -370,7 +372,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
Type: types.PromptMsg,
|
||||
Icon: user.Avatar,
|
||||
Content: prompt,
|
||||
Tokens: token,
|
||||
Tokens: promptToken,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
@@ -381,10 +383,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
}
|
||||
|
||||
// for reply
|
||||
token, err = utils.CalcTokens(message.Content, req.Model)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var replyToken = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
replyToken += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
replyToken += tokens
|
||||
} else {
|
||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -392,7 +401,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: token,
|
||||
Tokens: replyToken,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
@@ -404,14 +413,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
totalTokens = promptToken + replyToken
|
||||
} else {
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens += getTotalTokens(req)
|
||||
totalTokens = replyToken + getTotalTokens(req)
|
||||
}
|
||||
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||
if userVo.ChatConfig.ApiKey != "" { // 调用自己的 API KEY 不计算 token 消耗
|
||||
|
||||
Reference in New Issue
Block a user