From d8ff5987dde6b8fe2777f8598bf7443229d16991 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sat, 15 Jul 2023 21:52:30 +0800 Subject: [PATCH] feat: plugin function is ready --- api/core/app_server.go | 13 ++++++- api/core/types/function.go | 20 +++++------ api/handler/chat_handler.go | 42 +++++++++++----------- api/handler/verify_handler.go | 2 +- api/main.go | 24 ++++++++++--- api/service/function/function.go | 11 ++++++ api/service/function/tou_tiao.go | 60 +++++++++++++++++++++++++++++++ api/service/function/weibo_hot.go | 56 +++++++++++++++++++++++++++++ api/service/function/zao_bao.go | 18 ++++++---- 9 files changed, 200 insertions(+), 46 deletions(-) create mode 100644 api/service/function/function.go create mode 100644 api/service/function/tou_tiao.go create mode 100644 api/service/function/weibo_hot.go diff --git a/api/core/app_server.go b/api/core/app_server.go index 83253a80..22fa9640 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -2,6 +2,7 @@ package core import ( "chatplus/core/types" + "chatplus/service/function" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" @@ -30,9 +31,14 @@ type AppServer struct { ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function + Functions map[string]function.Function } -func NewServer(appConfig *types.AppConfig) *AppServer { +func NewServer( + appConfig *types.AppConfig, + funZaoBao function.FuncZaoBao, + funZhiHu function.FuncHeadlines, + funWeibo function.FuncWeiboHot) *AppServer { gin.SetMode(gin.ReleaseMode) gin.DefaultWriter = io.Discard return &AppServer{ @@ -43,6 +49,11 @@ func NewServer(appConfig *types.AppConfig) *AppServer { ChatSession: types.NewLMap[string, types.ChatSession](), ChatClients: types.NewLMap[string, *types.WsClient](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), + Functions: map[string]function.Function{ + types.FuncZaoBao: funZaoBao, + types.FuncWeibo: funWeibo, + types.FuncHeadLine: funZhiHu, + }, } } diff --git a/api/core/types/function.go b/api/core/types/function.go index b37cbeac..71e1e44d 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -22,9 +22,15 @@ type Property struct { Description string `json:"description"` } +const ( + FuncZaoBao = "zao_bao" // 每日早报 + FuncHeadLine = "headline" // 今日头条 + FuncWeibo = "weibo_hot" // 微博热搜 +) + var InnerFunctions = []Function{ { - Name: "zao_bao", + Name: FuncZaoBao, Description: "每日早报,获取当天全球的热门新闻事件列表", Parameters: Parameters{ @@ -39,7 +45,7 @@ var InnerFunctions = []Function{ }, }, { - Name: "weibo_hot", + Name: FuncWeibo, Description: "新浪微博热搜榜,微博当日热搜榜单", Parameters: Parameters{ Type: "object", @@ -54,8 +60,8 @@ var InnerFunctions = []Function{ }, { - Name: "zhihu_top", - Description: "知乎热榜,知乎当日话题讨论榜单", + Name: FuncHeadLine, + Description: "今日头条,给用户推荐当天的头条新闻,周榜热文", Parameters: Parameters{ Type: "object", Properties: map[string]Property{ @@ -68,9 +74,3 @@ var InnerFunctions = []Function{ }, }, } - -var FunctionNameMap = map[string]string{ - "zao_bao": "每日早报", - "weibo_hot": "微博热搜", - "zhihu_top": "知乎热榜", -} diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 0eebb59e..0101a66e 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -5,7 +5,6 @@ import ( "bytes" "chatplus/core" "chatplus/core/types" - "chatplus/service/function" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -30,12 +29,11 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" type ChatHandler struct { BaseHandler - db *gorm.DB - funcZaoBao *function.FuncZaoBao + db *gorm.DB } -func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler { - handler := ChatHandler{db: db, funcZaoBao: zaoBao} +func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { + handler := ChatHandler{db: db} handler.App = app return &handler } @@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession if !utils.IsEmptyValue(fun) { functionCall = true functionName = fun.Name + f := h.App.Functions[functionName] replyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])}) + replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())}) continue } @@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession if functionCall { // 调用函数完成任务 logger.Info(functionName) logger.Info(arguments) + f := h.App.Functions[functionName] // TODO 调用函数完成任务 - data, err := h.funcZaoBao.Fetch() + data, err := f.Invoke(arguments) if err != nil { replyChunkMessage(ws, types.WsMessage{ Type: types.WsMiddle, @@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession message.Content = strings.Join(contents, "") useMsg := types.Message{Role: "user", Content: prompt} - // 计算本次对话消耗的总 token 数量 - var totalTokens = 0 - if functionCall { // 函数名 + 参数 token - tokens, _ := utils.CalcTokens(functionName, req.Model) - totalTokens += tokens - tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) - totalTokens += tokens - } else { - req.Messages = append(req.Messages, message) - totalTokens += getTotalTokens(req) - } - replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)}) - // 更新上下文消息,如果是调用函数则不需要更新上下文 if userVo.ChatConfig.EnableContext && functionCall == false { chatCtx = append(chatCtx, useMsg) // 提问消息 @@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession logger.Error("failed to save reply history message: ", res.Error) } - // 统计用户 token 数量 + // 计算本次对话消耗的总 token 数量 + var totalTokens = 0 + if functionCall { // 函数名 + 参数 token + tokens, _ := utils.CalcTokens(functionName, req.Model) + totalTokens += tokens + tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model) + totalTokens += tokens + } else { + req.Messages = append(req.Messages, message) + totalTokens += getTotalTokens(req) + } + //replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)}) h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?", - historyUserMsg.Tokens+historyReplyMsg.Tokens)) + totalTokens)) } // 保存当前会话 diff --git a/api/handler/verify_handler.go b/api/handler/verify_handler.go index 7de5a31b..801a6b62 100644 --- a/api/handler/verify_handler.go +++ b/api/handler/verify_handler.go @@ -46,7 +46,7 @@ type CodeStats struct { // Token 生成自验证 token func (h *VerifyHandler) Token(c *gin.Context) { // 如果不是通过浏览器访问,则返回错误的 token - if c.GetHeader("Sec-Fetch-Mode") != "cors" { + if c.GetHeader("Sec-Invoke-Mode") != "cors" { token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix()) encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token)) if err != nil { diff --git a/api/main.go b/api/main.go index 07a6dc92..0b8b8e3a 100644 --- a/api/main.go +++ b/api/main.go @@ -102,12 +102,26 @@ func main() { }), // 创建函数 - fx.Provide(func() (*function.FuncZaoBao, error) { - token := os.Getenv("AL_API_TOKEN") - if token == "" { - return nil, errors.New("invalid AL api token") + fx.Provide(func() (function.FuncZaoBao, error) { + apiToken := os.Getenv("AL_API_TOKEN") + if apiToken == "" { + return function.FuncZaoBao{}, errors.New("invalid AL api token") } - return function.NewZaoBao(token), nil + return function.NewZaoBao(apiToken), nil + }), + fx.Provide(func() (function.FuncWeiboHot, error) { + apiToken := os.Getenv("AL_API_TOKEN") + if apiToken == "" { + return function.FuncWeiboHot{}, errors.New("invalid AL api token") + } + return function.NewWeiboHot(apiToken), nil + }), + fx.Provide(func() (function.FuncHeadlines, error) { + apiToken := os.Getenv("AL_API_TOKEN") + if apiToken == "" { + return function.FuncHeadlines{}, errors.New("invalid AL api token") + } + return function.NewHeadLines(apiToken), nil }), // 创建控制器 diff --git a/api/service/function/function.go b/api/service/function/function.go new file mode 100644 index 00000000..a0e8c964 --- /dev/null +++ b/api/service/function/function.go @@ -0,0 +1,11 @@ +package function + +type Function interface { + Invoke(...interface{}) (string, error) + Name() string +} + +type resVo struct { + Code int `json:"code"` + Msg string `json:"msg"` +} diff --git a/api/service/function/tou_tiao.go b/api/service/function/tou_tiao.go new file mode 100644 index 00000000..e9e2402c --- /dev/null +++ b/api/service/function/tou_tiao.go @@ -0,0 +1,60 @@ +package function + +import ( + "chatplus/utils" + "fmt" + "strings" +) + +// 今日头条函数实现 + +type FuncHeadlines struct { + name string + apiURL string + token string +} + +func NewHeadLines(token string) FuncHeadlines { + return FuncHeadlines{name: "今日头条", apiURL: "https://v2.alapi.cn/api/tophub/get", token: token} +} + +type HeadLineVo struct { + resVo + Data struct { + Name string `json:"name"` + LastUpdate string `json:"last_update"` + List []struct { + Title string `json:"title"` + Link string `json:"link"` + Other string `json:"other"` + } `json:"list"` + } `json:"data"` +} + +func (f FuncHeadlines) Invoke(...interface{}) (string, error) { + + url := fmt.Sprintf("%s?type=toutiao&token=%s", f.apiURL, f.token) + bytes, err := utils.HttpGet(url, "") + if err != nil { + return "", err + } + var res HeadLineVo + err = utils.JsonDecode(string(bytes), &res) + if err != nil { + return "", err + } + + if res.Code != 200 { + return "", fmt.Errorf("call api fail: %s", res.Msg) + } + builder := make([]string, 0) + builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Name, res.Data.LastUpdate)) + for i, v := range res.Data.List { + builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Link, v.Other)) + } + return strings.Join(builder, "\n\n"), nil +} + +func (f FuncHeadlines) Name() string { + return f.name +} diff --git a/api/service/function/weibo_hot.go b/api/service/function/weibo_hot.go new file mode 100644 index 00000000..ef568498 --- /dev/null +++ b/api/service/function/weibo_hot.go @@ -0,0 +1,56 @@ +package function + +import ( + "chatplus/utils" + "fmt" + "strings" +) + +// 微博热搜函数实现 + +type FuncWeiboHot struct { + name string + apiURL string + token string +} + +func NewWeiboHot(token string) FuncWeiboHot { + return FuncWeiboHot{name: "微博热搜", apiURL: "https://v2.alapi.cn/api/new/wbtop", token: token} +} + +type WeiBoVo struct { + resVo + Data []struct { + HotWord string `json:"hot_word"` + HotWordNum int `json:"hot_word_num"` + Url string `json:"url"` + } `json:"data"` +} + +func (f FuncWeiboHot) Invoke(...interface{}) (string, error) { + + url := fmt.Sprintf("%s?num=10&token=%s", f.apiURL, f.token) + bytes, err := utils.HttpGet(url, "") + if err != nil { + return "", err + } + var res WeiBoVo + err = utils.JsonDecode(string(bytes), &res) + if err != nil { + return "", err + } + + if res.Code != 200 { + return "", fmt.Errorf("call api fail: %s", res.Msg) + } + builder := make([]string, 0) + builder = append(builder, "**新浪微博今日热搜:**") + for i, v := range res.Data { + builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%d]", i+1, v.HotWord, v.Url, v.HotWordNum)) + } + return strings.Join(builder, "\n\n"), nil +} + +func (f FuncWeiboHot) Name() string { + return f.name +} diff --git a/api/service/function/zao_bao.go b/api/service/function/zao_bao.go index 851425d0..8065b237 100644 --- a/api/service/function/zao_bao.go +++ b/api/service/function/zao_bao.go @@ -9,17 +9,17 @@ import ( // 每日早报函数实现 type FuncZaoBao struct { + name string apiURL string token string } -func NewZaoBao(token string) *FuncZaoBao { - return &FuncZaoBao{apiURL: "https://v2.alapi.cn/api/zaobao", token: token} +func NewZaoBao(token string) FuncZaoBao { + return FuncZaoBao{name: "每日早报", apiURL: "https://v2.alapi.cn/api/zaobao", token: token} } -type resVo struct { - Code int `json:"code"` - Msg string `json:"msg"` +type ZaoBaoVo struct { + resVo Data struct { Date string `json:"date"` News []string `json:"news"` @@ -27,14 +27,14 @@ type resVo struct { } `json:"data"` } -func (f *FuncZaoBao) Fetch() (string, error) { +func (f FuncZaoBao) Invoke(...interface{}) (string, error) { url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token) bytes, err := utils.HttpGet(url, "") if err != nil { return "", err } - var res resVo + var res ZaoBaoVo err = utils.JsonDecode(string(bytes), &res) if err != nil { return "", err @@ -49,3 +49,7 @@ func (f *FuncZaoBao) Fetch() (string, error) { builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu)) return strings.Join(builder, "\n\n"), nil } + +func (f FuncZaoBao) Name() string { + return f.name +}