mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-19 17:56:39 +08:00
feat: plugin function is ready
This commit is contained in:
parent
d014d418e9
commit
d8ff5987dd
@ -2,6 +2,7 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/function"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
@ -30,9 +31,14 @@ type AppServer struct {
|
|||||||
ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId
|
ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId
|
||||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
|
Functions map[string]function.Function
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
func NewServer(
|
||||||
|
appConfig *types.AppConfig,
|
||||||
|
funZaoBao function.FuncZaoBao,
|
||||||
|
funZhiHu function.FuncHeadlines,
|
||||||
|
funWeibo function.FuncWeiboHot) *AppServer {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
gin.DefaultWriter = io.Discard
|
gin.DefaultWriter = io.Discard
|
||||||
return &AppServer{
|
return &AppServer{
|
||||||
@ -43,6 +49,11 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
|
Functions: map[string]function.Function{
|
||||||
|
types.FuncZaoBao: funZaoBao,
|
||||||
|
types.FuncWeibo: funWeibo,
|
||||||
|
types.FuncHeadLine: funZhiHu,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,9 +22,15 @@ type Property struct {
|
|||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
FuncZaoBao = "zao_bao" // 每日早报
|
||||||
|
FuncHeadLine = "headline" // 今日头条
|
||||||
|
FuncWeibo = "weibo_hot" // 微博热搜
|
||||||
|
)
|
||||||
|
|
||||||
var InnerFunctions = []Function{
|
var InnerFunctions = []Function{
|
||||||
{
|
{
|
||||||
Name: "zao_bao",
|
Name: FuncZaoBao,
|
||||||
Description: "每日早报,获取当天全球的热门新闻事件列表",
|
Description: "每日早报,获取当天全球的热门新闻事件列表",
|
||||||
Parameters: Parameters{
|
Parameters: Parameters{
|
||||||
|
|
||||||
@ -39,7 +45,7 @@ var InnerFunctions = []Function{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "weibo_hot",
|
Name: FuncWeibo,
|
||||||
Description: "新浪微博热搜榜,微博当日热搜榜单",
|
Description: "新浪微博热搜榜,微博当日热搜榜单",
|
||||||
Parameters: Parameters{
|
Parameters: Parameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
@ -54,8 +60,8 @@ var InnerFunctions = []Function{
|
|||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
Name: "zhihu_top",
|
Name: FuncHeadLine,
|
||||||
Description: "知乎热榜,知乎当日话题讨论榜单",
|
Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
|
||||||
Parameters: Parameters{
|
Parameters: Parameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]Property{
|
Properties: map[string]Property{
|
||||||
@ -68,9 +74,3 @@ var InnerFunctions = []Function{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var FunctionNameMap = map[string]string{
|
|
||||||
"zao_bao": "每日早报",
|
|
||||||
"weibo_hot": "微博热搜",
|
|
||||||
"zhihu_top": "知乎热榜",
|
|
||||||
}
|
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/function"
|
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
@ -30,12 +29,11 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
|||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
funcZaoBao *function.FuncZaoBao
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||||
handler := ChatHandler{db: db, funcZaoBao: zaoBao}
|
handler := ChatHandler{db: db}
|
||||||
handler.App = app
|
handler.App = app
|
||||||
return &handler
|
return &handler
|
||||||
}
|
}
|
||||||
@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
if !utils.IsEmptyValue(fun) {
|
if !utils.IsEmptyValue(fun) {
|
||||||
functionCall = true
|
functionCall = true
|
||||||
functionName = fun.Name
|
functionName = fun.Name
|
||||||
|
f := h.App.Functions[functionName]
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
|
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
if functionCall { // 调用函数完成任务
|
if functionCall { // 调用函数完成任务
|
||||||
logger.Info(functionName)
|
logger.Info(functionName)
|
||||||
logger.Info(arguments)
|
logger.Info(arguments)
|
||||||
|
f := h.App.Functions[functionName]
|
||||||
// TODO 调用函数完成任务
|
// TODO 调用函数完成任务
|
||||||
data, err := h.funcZaoBao.Fetch()
|
data, err := f.Invoke(arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
replyChunkMessage(ws, types.WsMessage{
|
replyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
message.Content = strings.Join(contents, "")
|
message.Content = strings.Join(contents, "")
|
||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
|
||||||
var totalTokens = 0
|
|
||||||
if functionCall { // 函数名 + 参数 token
|
|
||||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
} else {
|
|
||||||
req.Messages = append(req.Messages, message)
|
|
||||||
totalTokens += getTotalTokens(req)
|
|
||||||
}
|
|
||||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
|
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
logger.Error("failed to save reply history message: ", res.Error)
|
logger.Error("failed to save reply history message: ", res.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计用户 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
|
var totalTokens = 0
|
||||||
|
if functionCall { // 函数名 + 参数 token
|
||||||
|
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||||
|
totalTokens += tokens
|
||||||
|
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||||
|
totalTokens += tokens
|
||||||
|
} else {
|
||||||
|
req.Messages = append(req.Messages, message)
|
||||||
|
totalTokens += getTotalTokens(req)
|
||||||
|
}
|
||||||
|
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
totalTokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存当前会话
|
// 保存当前会话
|
||||||
|
@ -46,7 +46,7 @@ type CodeStats struct {
|
|||||||
// Token 生成自验证 token
|
// Token 生成自验证 token
|
||||||
func (h *VerifyHandler) Token(c *gin.Context) {
|
func (h *VerifyHandler) Token(c *gin.Context) {
|
||||||
// 如果不是通过浏览器访问,则返回错误的 token
|
// 如果不是通过浏览器访问,则返回错误的 token
|
||||||
if c.GetHeader("Sec-Fetch-Mode") != "cors" {
|
if c.GetHeader("Sec-Invoke-Mode") != "cors" {
|
||||||
token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix())
|
token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix())
|
||||||
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token))
|
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
24
api/main.go
24
api/main.go
@ -102,12 +102,26 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建函数
|
// 创建函数
|
||||||
fx.Provide(func() (*function.FuncZaoBao, error) {
|
fx.Provide(func() (function.FuncZaoBao, error) {
|
||||||
token := os.Getenv("AL_API_TOKEN")
|
apiToken := os.Getenv("AL_API_TOKEN")
|
||||||
if token == "" {
|
if apiToken == "" {
|
||||||
return nil, errors.New("invalid AL api token")
|
return function.FuncZaoBao{}, errors.New("invalid AL api token")
|
||||||
}
|
}
|
||||||
return function.NewZaoBao(token), nil
|
return function.NewZaoBao(apiToken), nil
|
||||||
|
}),
|
||||||
|
fx.Provide(func() (function.FuncWeiboHot, error) {
|
||||||
|
apiToken := os.Getenv("AL_API_TOKEN")
|
||||||
|
if apiToken == "" {
|
||||||
|
return function.FuncWeiboHot{}, errors.New("invalid AL api token")
|
||||||
|
}
|
||||||
|
return function.NewWeiboHot(apiToken), nil
|
||||||
|
}),
|
||||||
|
fx.Provide(func() (function.FuncHeadlines, error) {
|
||||||
|
apiToken := os.Getenv("AL_API_TOKEN")
|
||||||
|
if apiToken == "" {
|
||||||
|
return function.FuncHeadlines{}, errors.New("invalid AL api token")
|
||||||
|
}
|
||||||
|
return function.NewHeadLines(apiToken), nil
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建控制器
|
// 创建控制器
|
||||||
|
11
api/service/function/function.go
Normal file
11
api/service/function/function.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package function
|
||||||
|
|
||||||
|
type Function interface {
|
||||||
|
Invoke(...interface{}) (string, error)
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type resVo struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
}
|
60
api/service/function/tou_tiao.go
Normal file
60
api/service/function/tou_tiao.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package function
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/utils"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 今日头条函数实现
|
||||||
|
|
||||||
|
type FuncHeadlines struct {
|
||||||
|
name string
|
||||||
|
apiURL string
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHeadLines(token string) FuncHeadlines {
|
||||||
|
return FuncHeadlines{name: "今日头条", apiURL: "https://v2.alapi.cn/api/tophub/get", token: token}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeadLineVo struct {
|
||||||
|
resVo
|
||||||
|
Data struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
LastUpdate string `json:"last_update"`
|
||||||
|
List []struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Link string `json:"link"`
|
||||||
|
Other string `json:"other"`
|
||||||
|
} `json:"list"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?type=toutiao&token=%s", f.apiURL, f.token)
|
||||||
|
bytes, err := utils.HttpGet(url, "")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var res HeadLineVo
|
||||||
|
err = utils.JsonDecode(string(bytes), &res)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 200 {
|
||||||
|
return "", fmt.Errorf("call api fail: %s", res.Msg)
|
||||||
|
}
|
||||||
|
builder := make([]string, 0)
|
||||||
|
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Name, res.Data.LastUpdate))
|
||||||
|
for i, v := range res.Data.List {
|
||||||
|
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Link, v.Other))
|
||||||
|
}
|
||||||
|
return strings.Join(builder, "\n\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncHeadlines) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
56
api/service/function/weibo_hot.go
Normal file
56
api/service/function/weibo_hot.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package function
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/utils"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 微博热搜函数实现
|
||||||
|
|
||||||
|
type FuncWeiboHot struct {
|
||||||
|
name string
|
||||||
|
apiURL string
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWeiboHot(token string) FuncWeiboHot {
|
||||||
|
return FuncWeiboHot{name: "微博热搜", apiURL: "https://v2.alapi.cn/api/new/wbtop", token: token}
|
||||||
|
}
|
||||||
|
|
||||||
|
type WeiBoVo struct {
|
||||||
|
resVo
|
||||||
|
Data []struct {
|
||||||
|
HotWord string `json:"hot_word"`
|
||||||
|
HotWordNum int `json:"hot_word_num"`
|
||||||
|
Url string `json:"url"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?num=10&token=%s", f.apiURL, f.token)
|
||||||
|
bytes, err := utils.HttpGet(url, "")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var res WeiBoVo
|
||||||
|
err = utils.JsonDecode(string(bytes), &res)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 200 {
|
||||||
|
return "", fmt.Errorf("call api fail: %s", res.Msg)
|
||||||
|
}
|
||||||
|
builder := make([]string, 0)
|
||||||
|
builder = append(builder, "**新浪微博今日热搜:**")
|
||||||
|
for i, v := range res.Data {
|
||||||
|
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%d]", i+1, v.HotWord, v.Url, v.HotWordNum))
|
||||||
|
}
|
||||||
|
return strings.Join(builder, "\n\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncWeiboHot) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
@ -9,17 +9,17 @@ import (
|
|||||||
// 每日早报函数实现
|
// 每日早报函数实现
|
||||||
|
|
||||||
type FuncZaoBao struct {
|
type FuncZaoBao struct {
|
||||||
|
name string
|
||||||
apiURL string
|
apiURL string
|
||||||
token string
|
token string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewZaoBao(token string) *FuncZaoBao {
|
func NewZaoBao(token string) FuncZaoBao {
|
||||||
return &FuncZaoBao{apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
|
return FuncZaoBao{name: "每日早报", apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
|
||||||
}
|
}
|
||||||
|
|
||||||
type resVo struct {
|
type ZaoBaoVo struct {
|
||||||
Code int `json:"code"`
|
resVo
|
||||||
Msg string `json:"msg"`
|
|
||||||
Data struct {
|
Data struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
News []string `json:"news"`
|
News []string `json:"news"`
|
||||||
@ -27,14 +27,14 @@ type resVo struct {
|
|||||||
} `json:"data"`
|
} `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *FuncZaoBao) Fetch() (string, error) {
|
func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
||||||
|
|
||||||
url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token)
|
url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token)
|
||||||
bytes, err := utils.HttpGet(url, "")
|
bytes, err := utils.HttpGet(url, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
var res resVo
|
var res ZaoBaoVo
|
||||||
err = utils.JsonDecode(string(bytes), &res)
|
err = utils.JsonDecode(string(bytes), &res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -49,3 +49,7 @@ func (f *FuncZaoBao) Fetch() (string, error) {
|
|||||||
builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu))
|
builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu))
|
||||||
return strings.Join(builder, "\n\n"), nil
|
return strings.Join(builder, "\n\n"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f FuncZaoBao) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user