mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-19 01:36:38 +08:00
feat: plugin function is ready
This commit is contained in:
parent
d014d418e9
commit
d8ff5987dd
@ -2,6 +2,7 @@ package core
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/function"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
@ -30,9 +31,14 @@ type AppServer struct {
|
||||
ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId
|
||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
Functions map[string]function.Function
|
||||
}
|
||||
|
||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
func NewServer(
|
||||
appConfig *types.AppConfig,
|
||||
funZaoBao function.FuncZaoBao,
|
||||
funZhiHu function.FuncHeadlines,
|
||||
funWeibo function.FuncWeiboHot) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
@ -43,6 +49,11 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
Functions: map[string]function.Function{
|
||||
types.FuncZaoBao: funZaoBao,
|
||||
types.FuncWeibo: funWeibo,
|
||||
types.FuncHeadLine: funZhiHu,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,9 +22,15 @@ type Property struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
const (
|
||||
FuncZaoBao = "zao_bao" // 每日早报
|
||||
FuncHeadLine = "headline" // 今日头条
|
||||
FuncWeibo = "weibo_hot" // 微博热搜
|
||||
)
|
||||
|
||||
var InnerFunctions = []Function{
|
||||
{
|
||||
Name: "zao_bao",
|
||||
Name: FuncZaoBao,
|
||||
Description: "每日早报,获取当天全球的热门新闻事件列表",
|
||||
Parameters: Parameters{
|
||||
|
||||
@ -39,7 +45,7 @@ var InnerFunctions = []Function{
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "weibo_hot",
|
||||
Name: FuncWeibo,
|
||||
Description: "新浪微博热搜榜,微博当日热搜榜单",
|
||||
Parameters: Parameters{
|
||||
Type: "object",
|
||||
@ -54,8 +60,8 @@ var InnerFunctions = []Function{
|
||||
},
|
||||
|
||||
{
|
||||
Name: "zhihu_top",
|
||||
Description: "知乎热榜,知乎当日话题讨论榜单",
|
||||
Name: FuncHeadLine,
|
||||
Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
|
||||
Parameters: Parameters{
|
||||
Type: "object",
|
||||
Properties: map[string]Property{
|
||||
@ -68,9 +74,3 @@ var InnerFunctions = []Function{
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var FunctionNameMap = map[string]string{
|
||||
"zao_bao": "每日早报",
|
||||
"weibo_hot": "微博热搜",
|
||||
"zhihu_top": "知乎热榜",
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/function"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
@ -30,12 +29,11 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
||||
|
||||
type ChatHandler struct {
|
||||
BaseHandler
|
||||
db *gorm.DB
|
||||
funcZaoBao *function.FuncZaoBao
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBao) *ChatHandler {
|
||||
handler := ChatHandler{db: db, funcZaoBao: zaoBao}
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
handler := ChatHandler{db: db}
|
||||
handler.App = app
|
||||
return &handler
|
||||
}
|
||||
@ -279,8 +277,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionCall = true
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
continue
|
||||
}
|
||||
|
||||
@ -308,8 +307,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if functionCall { // 调用函数完成任务
|
||||
logger.Info(functionName)
|
||||
logger.Info(arguments)
|
||||
f := h.App.Functions[functionName]
|
||||
// TODO 调用函数完成任务
|
||||
data, err := h.funcZaoBao.Fetch()
|
||||
data, err := f.Invoke(arguments)
|
||||
if err != nil {
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
@ -338,19 +338,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
message.Content = strings.Join(contents, "")
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens += getTotalTokens(req)
|
||||
}
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if userVo.ChatConfig.EnableContext && functionCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
@ -409,9 +396,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 统计用户 token 数量
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
req.Messages = append(req.Messages, message)
|
||||
totalTokens += getTotalTokens(req)
|
||||
}
|
||||
//replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("\n\n `本轮对话共消耗 Token 数量: %d`", totalTokens+11)})
|
||||
h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
|
||||
historyUserMsg.Tokens+historyReplyMsg.Tokens))
|
||||
totalTokens))
|
||||
}
|
||||
|
||||
// 保存当前会话
|
||||
|
@ -46,7 +46,7 @@ type CodeStats struct {
|
||||
// Token 生成自验证 token
|
||||
func (h *VerifyHandler) Token(c *gin.Context) {
|
||||
// 如果不是通过浏览器访问,则返回错误的 token
|
||||
if c.GetHeader("Sec-Fetch-Mode") != "cors" {
|
||||
if c.GetHeader("Sec-Invoke-Mode") != "cors" {
|
||||
token := fmt.Sprintf("%s:%d", utils.RandString(32), time.Now().Unix())
|
||||
encrypt, err := utils.AesEncrypt(h.App.Config.AesEncryptKey, []byte(token))
|
||||
if err != nil {
|
||||
|
24
api/main.go
24
api/main.go
@ -102,12 +102,26 @@ func main() {
|
||||
}),
|
||||
|
||||
// 创建函数
|
||||
fx.Provide(func() (*function.FuncZaoBao, error) {
|
||||
token := os.Getenv("AL_API_TOKEN")
|
||||
if token == "" {
|
||||
return nil, errors.New("invalid AL api token")
|
||||
fx.Provide(func() (function.FuncZaoBao, error) {
|
||||
apiToken := os.Getenv("AL_API_TOKEN")
|
||||
if apiToken == "" {
|
||||
return function.FuncZaoBao{}, errors.New("invalid AL api token")
|
||||
}
|
||||
return function.NewZaoBao(token), nil
|
||||
return function.NewZaoBao(apiToken), nil
|
||||
}),
|
||||
fx.Provide(func() (function.FuncWeiboHot, error) {
|
||||
apiToken := os.Getenv("AL_API_TOKEN")
|
||||
if apiToken == "" {
|
||||
return function.FuncWeiboHot{}, errors.New("invalid AL api token")
|
||||
}
|
||||
return function.NewWeiboHot(apiToken), nil
|
||||
}),
|
||||
fx.Provide(func() (function.FuncHeadlines, error) {
|
||||
apiToken := os.Getenv("AL_API_TOKEN")
|
||||
if apiToken == "" {
|
||||
return function.FuncHeadlines{}, errors.New("invalid AL api token")
|
||||
}
|
||||
return function.NewHeadLines(apiToken), nil
|
||||
}),
|
||||
|
||||
// 创建控制器
|
||||
|
11
api/service/function/function.go
Normal file
11
api/service/function/function.go
Normal file
@ -0,0 +1,11 @@
|
||||
package function
|
||||
|
||||
type Function interface {
|
||||
Invoke(...interface{}) (string, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
type resVo struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
60
api/service/function/tou_tiao.go
Normal file
60
api/service/function/tou_tiao.go
Normal file
@ -0,0 +1,60 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 今日头条函数实现
|
||||
|
||||
type FuncHeadlines struct {
|
||||
name string
|
||||
apiURL string
|
||||
token string
|
||||
}
|
||||
|
||||
func NewHeadLines(token string) FuncHeadlines {
|
||||
return FuncHeadlines{name: "今日头条", apiURL: "https://v2.alapi.cn/api/tophub/get", token: token}
|
||||
}
|
||||
|
||||
type HeadLineVo struct {
|
||||
resVo
|
||||
Data struct {
|
||||
Name string `json:"name"`
|
||||
LastUpdate string `json:"last_update"`
|
||||
List []struct {
|
||||
Title string `json:"title"`
|
||||
Link string `json:"link"`
|
||||
Other string `json:"other"`
|
||||
} `json:"list"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (f FuncHeadlines) Invoke(...interface{}) (string, error) {
|
||||
|
||||
url := fmt.Sprintf("%s?type=toutiao&token=%s", f.apiURL, f.token)
|
||||
bytes, err := utils.HttpGet(url, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var res HeadLineVo
|
||||
err = utils.JsonDecode(string(bytes), &res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 200 {
|
||||
return "", fmt.Errorf("call api fail: %s", res.Msg)
|
||||
}
|
||||
builder := make([]string, 0)
|
||||
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Name, res.Data.LastUpdate))
|
||||
for i, v := range res.Data.List {
|
||||
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Link, v.Other))
|
||||
}
|
||||
return strings.Join(builder, "\n\n"), nil
|
||||
}
|
||||
|
||||
func (f FuncHeadlines) Name() string {
|
||||
return f.name
|
||||
}
|
56
api/service/function/weibo_hot.go
Normal file
56
api/service/function/weibo_hot.go
Normal file
@ -0,0 +1,56 @@
|
||||
package function
|
||||
|
||||
import (
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 微博热搜函数实现
|
||||
|
||||
type FuncWeiboHot struct {
|
||||
name string
|
||||
apiURL string
|
||||
token string
|
||||
}
|
||||
|
||||
func NewWeiboHot(token string) FuncWeiboHot {
|
||||
return FuncWeiboHot{name: "微博热搜", apiURL: "https://v2.alapi.cn/api/new/wbtop", token: token}
|
||||
}
|
||||
|
||||
type WeiBoVo struct {
|
||||
resVo
|
||||
Data []struct {
|
||||
HotWord string `json:"hot_word"`
|
||||
HotWordNum int `json:"hot_word_num"`
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (f FuncWeiboHot) Invoke(...interface{}) (string, error) {
|
||||
|
||||
url := fmt.Sprintf("%s?num=10&token=%s", f.apiURL, f.token)
|
||||
bytes, err := utils.HttpGet(url, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var res WeiBoVo
|
||||
err = utils.JsonDecode(string(bytes), &res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 200 {
|
||||
return "", fmt.Errorf("call api fail: %s", res.Msg)
|
||||
}
|
||||
builder := make([]string, 0)
|
||||
builder = append(builder, "**新浪微博今日热搜:**")
|
||||
for i, v := range res.Data {
|
||||
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%d]", i+1, v.HotWord, v.Url, v.HotWordNum))
|
||||
}
|
||||
return strings.Join(builder, "\n\n"), nil
|
||||
}
|
||||
|
||||
func (f FuncWeiboHot) Name() string {
|
||||
return f.name
|
||||
}
|
@ -9,17 +9,17 @@ import (
|
||||
// 每日早报函数实现
|
||||
|
||||
type FuncZaoBao struct {
|
||||
name string
|
||||
apiURL string
|
||||
token string
|
||||
}
|
||||
|
||||
func NewZaoBao(token string) *FuncZaoBao {
|
||||
return &FuncZaoBao{apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
|
||||
func NewZaoBao(token string) FuncZaoBao {
|
||||
return FuncZaoBao{name: "每日早报", apiURL: "https://v2.alapi.cn/api/zaobao", token: token}
|
||||
}
|
||||
|
||||
type resVo struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
type ZaoBaoVo struct {
|
||||
resVo
|
||||
Data struct {
|
||||
Date string `json:"date"`
|
||||
News []string `json:"news"`
|
||||
@ -27,14 +27,14 @@ type resVo struct {
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func (f *FuncZaoBao) Fetch() (string, error) {
|
||||
func (f FuncZaoBao) Invoke(...interface{}) (string, error) {
|
||||
|
||||
url := fmt.Sprintf("%s?format=json&token=%s", f.apiURL, f.token)
|
||||
bytes, err := utils.HttpGet(url, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var res resVo
|
||||
var res ZaoBaoVo
|
||||
err = utils.JsonDecode(string(bytes), &res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -49,3 +49,7 @@ func (f *FuncZaoBao) Fetch() (string, error) {
|
||||
builder = append(builder, fmt.Sprintf("%s", res.Data.WeiYu))
|
||||
return strings.Join(builder, "\n\n"), nil
|
||||
}
|
||||
|
||||
func (f FuncZaoBao) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user