package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * Copyright 2023 The Geek-AI Authors. All rights reserved. // * Use of this source code is governed by a Apache-2.0 license // * that can be found in the LICENSE file. // * @Author yangjian102621@163.com // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( "errors" "fmt" "geekai/core" "geekai/core/types" "geekai/service" "geekai/service/dalle" "geekai/service/oss" "geekai/store/model" "geekai/store/vo" "geekai/utils" "geekai/utils/resp" "strings" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/imroc/req/v3" "gorm.io/gorm" ) type FunctionHandler struct { BaseHandler uploadManager *oss.UploaderManager dallService *dalle.Service userService *service.UserService } func NewFunctionHandler( server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager, dallService *dalle.Service, userService *service.UserService) *FunctionHandler { return &FunctionHandler{ BaseHandler: BaseHandler{ App: server, DB: db, }, uploadManager: manager, dallService: dallService, userService: userService, } } // RegisterRoutes 注册路由 func (h *FunctionHandler) RegisterRoutes() { group := h.App.Engine.Group("/api/function/") group.GET("list", h.List) // 需要用户授权的接口 group.POST("weibo", h.WeiBo) group.POST("zaobao", h.ZaoBao) group.POST("dalle3", h.Dall3) } type resVo struct { Code types.BizCode `json:"code"` Message string `json:"message"` Data struct { Title string `json:"title"` UpdatedAt string `json:"updated_at"` Items []dataItem `json:"items"` } `json:"data"` } type dataItem struct { Title string `json:"title"` Url string `json:"url"` Remark string `json:"remark"` } // check authorization func (h *FunctionHandler) checkAuth(c *gin.Context) error { tokenString := c.GetHeader(types.UserAuthHeader) token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(h.App.Config.Session.SecretKey), nil }) if err != nil { return fmt.Errorf("error with parse auth token: %v", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok || !token.Valid { return errors.New("token is invalid") } expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) if expr > 0 && int64(expr) < time.Now().Unix() { return errors.New("token is expired") } return nil } // WeiBo 微博热搜 func (h *FunctionHandler) WeiBo(c *gin.Context) { if err := h.checkAuth(c); err != nil { resp.ERROR(c, err.Error()) return } url := fmt.Sprintf("%s/api/weibo/fetch", types.GeekAPIURL) var res resVo r, err := req.C().R(). SetHeader("Authorization", "Bearer geekai-plus"). SetSuccessResult(&res).Get(url) if err != nil { resp.ERROR(c, fmt.Sprintf("%v", err)) return } if r.IsErrorState() { resp.ERROR(c, fmt.Sprintf("error http code status: %v", r.Status)) } if res.Code != types.Success { resp.ERROR(c, res.Message) return } builder := make([]string, 0) builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt)) for i, v := range res.Data.Items { builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark)) } resp.SUCCESS(c, strings.Join(builder, "\n\n")) } // ZaoBao 今日早报 func (h *FunctionHandler) ZaoBao(c *gin.Context) { if err := h.checkAuth(c); err != nil { resp.ERROR(c, err.Error()) return } url := fmt.Sprintf("%s/api/zaobao/fetch", types.GeekAPIURL) var res resVo r, err := req.C().R(). SetHeader("Authorization", "Bearer geekai-plus"). SetSuccessResult(&res).Get(url) if err != nil { resp.ERROR(c, fmt.Sprintf("%v", err)) return } if r.IsErrorState() { resp.ERROR(c, fmt.Sprintf("%v", r.Err)) return } if res.Code != types.Success { resp.ERROR(c, res.Message) return } builder := make([]string, 0) builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt)) for _, v := range res.Data.Items { builder = append(builder, v.Title) } builder = append(builder, res.Data.Title) resp.SUCCESS(c, strings.Join(builder, "\n\n")) } // Dall3 DallE3 AI 绘图 func (h *FunctionHandler) Dall3(c *gin.Context) { if err := h.checkAuth(c); err != nil { resp.ERROR(c, err.Error()) return } // var params map[string]interface{} // if err := c.ShouldBindJSON(¶ms); err != nil { // resp.ERROR(c, types.InvalidArgs) // return // } // logger.Debugf("绘画参数:%+v", params) // var user model.User // res := h.DB.Where("id = ?", params["user_id"]).First(&user) // if res.Error != nil { // resp.ERROR(c, "当前用户不存在!") // return // } // if user.Power < h.App.SysConfig.Base.DallPower { // resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足") // return // } // // create dall task // prompt := utils.InterfaceToString(params["prompt"]) // task := types.DallTask{ // UserId: user.Id, // Prompt: prompt, // ModelId: 0, // ModelName: "dall-e-3", // TranslateModelId: h.App.SysConfig.Base.AssistantModelId, // N: 1, // Quality: "standard", // Size: "1024x1024", // Style: "vivid", // Power: h.App.SysConfig.Base.DallPower, // } // job := model.DallJob{ // UserId: user.Id, // Prompt: prompt, // Power: h.App.SysConfig.Base.DallPower, // TaskInfo: utils.JsonEncode(task), // } // err := h.DB.Create(&job).Error // if err != nil { // resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error()) // return // } // task.Id = job.Id // content, err := h.dallService.Image(task, true) // if err != nil { // resp.ERROR(c, "任务执行失败:"+err.Error()) // return // } // // 扣减算力 // err = h.userService.DecreasePower(user.Id, job.Power, model.PowerLog{ // Type: types.PowerConsume, // Model: task.ModelName, // Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)), // }) // if err != nil { // resp.ERROR(c, "扣减算力失败:"+err.Error()) // return // } // resp.SUCCESS(c, content) // } // // 实现一个联网搜索的函数工具,采用爬虫实现 // func (h *FunctionHandler) WebSearch(c *gin.Context) { // if err := h.checkAuth(c); err != nil { // resp.ERROR(c, err.Error()) // return // } // var params map[string]interface{} // if err := c.ShouldBindJSON(¶ms); err != nil { // resp.ERROR(c, types.InvalidArgs) // return // } // // 从参数中获取搜索关键词 // keyword, ok := params["keyword"].(string) // if !ok || keyword == "" { // resp.ERROR(c, "搜索关键词不能为空") // return // } // // 从参数中获取最大页数,默认为1页 // maxPages := 1 // if pages, ok := params["max_pages"].(float64); ok { // maxPages = int(pages) // } // // 获取用户ID // userID, ok := params["user_id"].(float64) // if !ok { // resp.ERROR(c, "用户ID不能为空") // return // } // // 查询用户信息 // var user model.User // res := h.DB.Where("id = ?", int(userID)).First(&user) // if res.Error != nil { // resp.ERROR(c, "用户不存在") // return // } // // 检查用户算力是否足够 // searchPower := 1 // 每次搜索消耗1点算力 // if user.Power < searchPower { // resp.ERROR(c, "算力不足,无法执行网络搜索") // return // } // // 执行网络搜索 // searchResults, err := crawler.SearchWeb(keyword, maxPages) // if err != nil { // resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err)) // return // } // // 扣减用户算力 // err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{ // Type: types.PowerConsume, // Model: "web_search", // Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)), // }) // if err != nil { // resp.ERROR(c, "扣减算力失败:"+err.Error()) // return // } // 返回搜索结果 // resp.SUCCESS(c, searchResults) } // List 获取所有的工具函数列表 func (h *FunctionHandler) List(c *gin.Context) { var items []model.Function err := h.DB.Where("enabled", true).Find(&items).Error if err != nil { resp.ERROR(c, err.Error()) return } tools := make([]vo.Function, 0) for _, v := range items { var f vo.Function err = utils.CopyObject(v, &f) if err != nil { continue } f.Action = "" f.Token = "" tools = append(tools, f) } resp.SUCCESS(c, tools) }