feat: plugin function is ready

This commit is contained in:
RockYang 2023-07-15 21:52:30 +08:00
parent d014d418e9
commit d8ff5987dd
9 changed files with 200 additions and 46 deletions

View File

@ -2,6 +2,7 @@ package core
import (
"chatplus/core/types"
"chatplus/service/function"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
@ -30,9 +31,14 @@ type AppServer struct {
ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
Functions map[string]function.Function
}
func NewServer(appConfig *types.AppConfig) *AppServer {
func NewServer(
appConfig *types.AppConfig,
funZaoBao function.FuncZaoBao,
funZhiHu function.FuncHeadlines,
funWeibo function.FuncWeiboHot) *AppServer {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
return &AppServer{
@ -43,6 +49,11 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
ChatSession: types.NewLMap[string, types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
Functions: map[string]function.Function{
types.FuncZaoBao: funZaoBao,
types.FuncWeibo: funWeibo,
types.FuncHeadLine: funZhiHu,
},
}
}

View File

@ -22,9 +22,15 @@ type Property struct {
Description string `json:"description"`
}
const (
FuncZaoBao = "zao_bao" // 每日早报
FuncHeadLine = "headline" // 今日头条
FuncWeibo = "weibo_hot" // 微博热搜
)
var InnerFunctions = []Function{
{
Name: "zao_bao",
Name: FuncZaoBao,
Description: "每日早报,获取当天全球的热门新闻事件列表",
Parameters: Parameters{
@ -39,7 +45,7 @@ var InnerFunctions = []Function{
},
},
{
Name: "weibo_hot",
Name: FuncWeibo,
Description: "新浪微博热搜榜,微博当日热搜榜单",
Parameters: Parameters{
Type: "object",
@ -54,8 +60,8 @@ var InnerFunctions = []Function{
},
{
Name: "zhihu_top",
Description: "知乎热榜,知乎当日话题讨论榜单",
Name: FuncHeadLine,
Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
Parameters: Parameters{
Type: "object",
Properties: map[string]Property{
@ -68,9 +74,3 @@ var InnerFunctions = []Function{
},
},
}
var FunctionNameMap = map[string]string{
"zao_bao": "每日早报",
"weibo_hot": "微博热搜",
"zhihu_top": "知乎热榜",
}

View File

@ -5,7 +5,6 @@ import (
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/service/function"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@ -30,12 +29,11 @@ const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
type ChatHandler struct {
BaseHandler
db *gorm.DB
funcZaoBao *function.FuncZaoBao
db *gorm.DB
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler {
handler := ChatHandler{db: db, funcZaoBao: zaoBao}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
handler.App = app
return &handler
}
@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if !utils.IsEmptyValue(fun) {
functionCall = true
functionName = fun.Name
f := h.App.Functions[functionName]
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
continue
}
@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if functionCall { // 调用函数完成任务
logger.Info(functionName)
logger.Info(arguments)
f := h.App.Functions[functionName]
// TODO 调用函数完成任务
data, err := h.funcZaoBao.Fetch()
data, err := f.Invoke(arguments)
if err != nil {
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
req.Messages = append(req.Messages, message)
totalTokens += getTotalTokens(req)
}
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
// 更新上下文消息,如果是调用函数则不需要更新上下文
if userVo.ChatConfig.EnableContext && functionCall == false {
chatCtx = append(chatCtx, useMsg) // 提问消息
@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
logger.Error("failed to save reply history message: ", res.Error)
}
// 统计用户 token 数量
// 计算本次对话消耗的总 token 数量
var totalTokens = 0
if functionCall { // 函数名 + 参数 token
tokens, _ := utils.CalcTokens(functionName, req.Model)
totalTokens += tokens
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
totalTokens += tokens
} else {
req.Messages = append(req.Messages, message)
totalTokens += getTotalTokens(req)
}
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
historyUserMsg.Tokens+historyReplyMsg.Tokens))
totalTokens))
}
// 保存当前会话

View File

@ -46,7 +46,7 @@ type CodeStats struct {
// Token 生成自验证 token
func (h *VerifyHandler) Token(c *gin.Context) {
// 如果不是通过浏览器访问,则返回错误的 token
if c.GetHeader("Sec-Fetch-Mode") != "cors" {
if c.GetHeader("Sec-Invoke-Mode") != "cors" {
token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix())
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token))
if err != nil {

View File

@ -102,12 +102,26 @@ func main() {
}),
// 创建函数
fx.Provide(func() (*function.FuncZaoBao, error) {
token := os.Getenv("AL_API_TOKEN")
if token == "" {
return nil, errors.New("invalid AL api token")
fx.Provide(func() (function.FuncZaoBao, error) {
apiToken := os.Getenv("AL_API_TOKEN")
if apiToken == "" {
return function.FuncZaoBao{}, errors.New("invalid AL api token")
}
return function.NewZaoBao(token), nil
return function.NewZaoBao(apiToken), nil
}),
fx.Provide(func() (function.FuncWeiboHot, error) {
apiToken := os.Getenv("AL_API_TOKEN")
if apiToken == "" {
return function.FuncWeiboHot{}, errors.New("invalid AL api token")
}
return function.NewWeiboHot(apiToken), nil
}),
fx.Provide(func() (function.FuncHeadlines, error) {
apiToken := os.Getenv("AL_API_TOKEN")
if apiToken == "" {
return function.FuncHeadlines{}, errors.New("invalid AL api token")
}
return function.NewHeadLines(apiToken), nil
}),
// 创建控制器

View File

@ -0,0 +1,11 @@
package function
type Function interface {
Invoke(...interface{}) (string, error)
Name() string
}
type resVo struct {
Code int `json:"code"`
Msg string `json:"msg"`
}

View File

@ -0,0 +1,60 @@
package function
import (
"chatplus/utils"
"fmt"
"strings"
)
// 今日头条函数实现
type FuncHeadlines struct {
name string
apiURL string
token string
}
func NewHeadLines(token string) FuncHeadlines {
return FuncHeadlines{name: "今日头条", apiURL: "https://v2.alapi.cn/api/tophub/get", token: token}
}
type HeadLineVo struct {
resVo
Data struct {
Name string `json:"name"`
LastUpdate string `json:"last_update"`
List []struct {
Title string `json:"title"`
Link string `json:"link"`
Other string `json:"other"`
} `json:"list"`
} `json:"data"`
}
func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
url := fmt.Sprintf("%s?type=toutiao&token=%s", f.apiURL, f.token)
bytes, err := utils.HttpGet(url, "")
if err != nil {
return "", err
}
var res HeadLineVo
err = utils.JsonDecode(string(bytes), &res)
if err != nil {
return "", err
}
if res.Code != 200 {
return "", fmt.Errorf("call api fail: %s", res.Msg)
}
builder := make([]string, 0)
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Name, res.Data.LastUpdate))
for i, v := range res.Data.List {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Link, v.Other))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncHeadlines) Name() string {
return f.name
}

View File

@ -0,0 +1,56 @@
package function
import (
"chatplus/utils"
"fmt"
"strings"
)
// 微博热搜函数实现
type FuncWeiboHot struct {
name string
apiURL string
token string
}
func NewWeiboHot(token string) FuncWeiboHot {
return FuncWeiboHot{name: "微博热搜", apiURL: "https://v2.alapi.cn/api/new/wbtop", token: token}
}
type WeiBoVo struct {
resVo
Data []struct {
HotWord string `json:"hot_word"`
HotWordNum int `json:"hot_word_num"`
Url string `json:"url"`
} `json:"data"`
}
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
url := fmt.Sprintf("%s?num=10&token=%s", f.apiURL, f.token)
bytes, err := utils.HttpGet(url, "")
if err != nil {
return "", err
}
var res WeiBoVo
err = utils.JsonDecode(string(bytes), &res)
if err != nil {
return "", err
}
if res.Code != 200 {
return "", fmt.Errorf("call api fail: %s", res.Msg)
}
builder := make([]string, 0)
builder = append(builder, "**新浪微博今日热搜:**")
for i, v := range res.Data {
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%d]", i+1, v.HotWord, v.Url, v.HotWordNum))
}
return strings.Join(builder, "\n\n"), nil
}
func (f FuncWeiboHot) Name() string {
return f.name
}

View File

@ -9,17 +9,17 @@ import (
// 每日早报函数实现
type FuncZaoBao struct {
name string
apiURL string
token string
}
func NewZaoBao(token string) *FuncZaoBao {
return &FuncZaoBao{apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
func NewZaoBao(token string) FuncZaoBao {
return FuncZaoBao{name: "每日早报", apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
}
type resVo struct {
Code int `json:"code"`
Msg string `json:"msg"`
type ZaoBaoVo struct {
resVo
Data struct {
Date string `json:"date"`
News []string `json:"news"`
@ -27,14 +27,14 @@ type resVo struct {
} `json:"data"`
}
func (f *FuncZaoBao) Fetch() (string, error) {
func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token)
bytes, err := utils.HttpGet(url, "")
if err != nil {
return "", err
}
var res resVo
var res ZaoBaoVo
err = utils.JsonDecode(string(bytes), &res)
if err != nil {
return "", err
@ -49,3 +49,7 @@ func (f *FuncZaoBao) Fetch() (string, error) {
builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu))
return strings.Join(builder, "\n\n"), nil
}
func (f FuncZaoBao) Name() string {
return f.name
}