package handler

// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import (
	"errors"
	"fmt"
	"geekai/core"
	"geekai/core/types"
	"geekai/service"
	"geekai/service/dalle"
	"geekai/service/oss"
	"geekai/store/model"
	"geekai/store/vo"
	"geekai/utils"
	"geekai/utils/resp"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/imroc/req/v3"
	"gorm.io/gorm"
)

// FunctionHandler serves the /api/function/ endpoints: the list of enabled
// tool functions plus the built-in WeiBo, ZaoBao and DALL-E 3 functions.
type FunctionHandler struct {
	BaseHandler
	uploadManager *oss.UploaderManager
	dallService   *dalle.Service
	userService   *service.UserService
}

// NewFunctionHandler builds a FunctionHandler from its dependencies.
// The config parameter is accepted for signature compatibility but is not
// used by this constructor.
func NewFunctionHandler(
	server *core.AppServer,
	db *gorm.DB,
	config *types.AppConfig,
	manager *oss.UploaderManager,
	dallService *dalle.Service,
	userService *service.UserService) *FunctionHandler {
	return &FunctionHandler{
		BaseHandler: BaseHandler{
			App: server,
			DB:  db,
		},
		uploadManager: manager,
		dallService:   dallService,
		userService:   userService,
	}
}

// RegisterRoutes registers the handler's routes under /api/function/.
func (h *FunctionHandler) RegisterRoutes() {
	group := h.App.Engine.Group("/api/function/")
	group.GET("list", h.List)

	// Endpoints that require an authorization token (validated by checkAuth
	// inside each handler).
	group.POST("weibo", h.WeiBo)
	group.POST("zaobao", h.ZaoBao)
	group.POST("dalle3", h.Dall3)
}

// resVo is the JSON envelope decoded from the remote API responses used by
// WeiBo and ZaoBao.
type resVo struct {
	Code    types.BizCode `json:"code"`
	Message string        `json:"message"`
	Data    struct {
		Title     string     `json:"title"`
		UpdatedAt string     `json:"updated_at"`
		Items     []dataItem `json:"items"`
	} `json:"data"`
}

// dataItem is a single entry of resVo.Data.Items.
type dataItem struct {
	Title  string `json:"title"`
	Url    string `json:"url"`
	Remark string `json:"remark"`
}

// checkAuth validates the JWT carried in the types.UserAuthHeader request
// header. The token must be signed with an HMAC method using the session
// secret key. If the "expired" claim parses to a positive integer, it is
// treated as a Unix timestamp and must not lie in the past.
//
// It returns nil when the token is acceptable and a descriptive error
// otherwise.
func (h *FunctionHandler) checkAuth(c *gin.Context) error {
	tokenString := c.GetHeader(types.UserAuthHeader)
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Reject anything that is not HMAC so the secret key is never used
		// to verify a token signed with a different algorithm family.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(h.App.Config.Session.SecretKey), nil
	})
	if err != nil {
		return fmt.Errorf("error with parse auth token: %w", err)
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return errors.New("token is invalid")
	}

	// A zero (or unparsable) "expired" claim skips the expiry check.
	expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
	if expr > 0 && int64(expr) < time.Now().Unix() {
		return errors.New("token is expired")
	}
	return nil
}

// fetchGeekAPI issues a GET request to types.GeekAPIURL followed by path and
// decodes a successful response into a resVo.
//
// It returns an error when the request itself fails, when the HTTP response
// is in an error state, or when the decoded business code is not
// types.Success (in which case the error text is the response message).
func (h *FunctionHandler) fetchGeekAPI(path string) (*resVo, error) {
	var res resVo
	r, err := req.C().R().
		SetHeader("Authorization", "Bearer geekai-plus").
		SetSuccessResult(&res).
		Get(fmt.Sprintf("%s%s", types.GeekAPIURL, path))
	if err != nil {
		return nil, err
	}
	if r.IsErrorState() {
		return nil, fmt.Errorf("error http code status: %v", r.Status)
	}
	if res.Code != types.Success {
		return nil, errors.New(res.Message)
	}
	return &res, nil
}

// WeiBo returns the Weibo hot-search list rendered as Markdown: a header
// line with the title and update time followed by one numbered link per item.
func (h *FunctionHandler) WeiBo(c *gin.Context) {
	if err := h.checkAuth(c); err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	res, err := h.fetchGeekAPI("/api/weibo/fetch")
	if err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	lines := make([]string, 0, len(res.Data.Items)+1)
	lines = append(lines, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
	for i, v := range res.Data.Items {
		lines = append(lines, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
	}
	resp.SUCCESS(c, strings.Join(lines, "\n\n"))
}

// ZaoBao returns today's morning briefing rendered as Markdown: a header
// line with the update time, one line per item title, and the overall title
// as the last line.
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
	if err := h.checkAuth(c); err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	res, err := h.fetchGeekAPI("/api/zaobao/fetch")
	if err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	lines := make([]string, 0, len(res.Data.Items)+2)
	lines = append(lines, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
	for _, v := range res.Data.Items {
		lines = append(lines, v.Title)
	}
	lines = append(lines, res.Data.Title)
	resp.SUCCESS(c, strings.Join(lines, "\n\n"))
}

// Dall3 creates and synchronously runs a DALL-E 3 drawing task.
//
// The JSON body is read as a generic map; the "user_id" and "prompt" keys
// are used. The handler picks the first enabled model of type "img", checks
// that the user has at least that model's power, stores a DallJob, runs the
// task through the dalle service, and then deducts the power from the user.
func (h *FunctionHandler) Dall3(c *gin.Context) {
	if err := h.checkAuth(c); err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	var params map[string]interface{}
	if err := c.ShouldBindJSON(&params); err != nil {
		resp.ERROR(c, types.InvalidArgs)
		return
	}

	var chatModel model.ChatModel
	res := h.DB.Where("type = ?", "img").Where("enabled", true).First(&chatModel)
	if res.Error != nil {
		resp.ERROR(c, "没有找到可用的AI绘图模型!")
		return
	}

	logger.Debugf("绘画参数:%+v", params)

	// NOTE(review): user_id is taken from the request body and is not tied
	// to the token verified by checkAuth — confirm that callers are trusted.
	var user model.User
	res = h.DB.Where("id = ?", params["user_id"]).First(&user)
	if res.Error != nil {
		resp.ERROR(c, "当前用户不存在!")
		return
	}

	if user.Power < chatModel.Power {
		resp.ERROR(c, "创建绘图任务失败,算力不足")
		return
	}

	// Create the DALL-E task with fixed generation settings.
	prompt := utils.InterfaceToString(params["prompt"])
	task := types.DallTask{
		UserId:           user.Id,
		Prompt:           prompt,
		ModelId:          chatModel.Id,
		ModelName:        chatModel.Value,
		TranslateModelId: h.App.SysConfig.Base.AssistantModelId,
		N:                1,
		Quality:          "standard",
		Size:             "1024x1024",
		Style:            "vivid",
		Power:            chatModel.Power,
	}
	job := model.DallJob{
		UserId:   user.Id,
		Prompt:   prompt,
		Power:    chatModel.Power,
		TaskInfo: utils.JsonEncode(task),
	}
	if err := h.DB.Create(&job).Error; err != nil {
		resp.ERROR(c, "创建绘图任务失败:"+err.Error())
		return
	}

	// The task id is only known after the job row has been inserted.
	task.Id = job.Id
	content, err := h.dallService.Image(task, true)
	if err != nil {
		resp.ERROR(c, "任务执行失败:"+err.Error())
		return
	}

	// Deduct the consumed power only after the image call has succeeded.
	err = h.userService.DecreasePower(user.Id, job.Power, model.PowerLog{
		Type:   types.PowerConsume,
		Model:  task.ModelName,
		Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
	})
	if err != nil {
		resp.ERROR(c, "扣减算力失败:"+err.Error())
		return
	}

	resp.SUCCESS(c, content)
}

// List returns every enabled tool function. The Action and Token fields are
// blanked before the list is sent to the client.
func (h *FunctionHandler) List(c *gin.Context) {
	var items []model.Function
	if err := h.DB.Where("enabled", true).Find(&items).Error; err != nil {
		resp.ERROR(c, err.Error())
		return
	}

	// Start from an empty, non-nil slice, as the original code did.
	tools := make([]vo.Function, 0, len(items))
	for _, v := range items {
		var f vo.Function
		// Best effort: an item that cannot be copied is left out of the
		// result rather than failing the whole request.
		if err := utils.CopyObject(v, &f); err != nil {
			continue
		}
		f.Action = ""
		f.Token = ""
		tools = append(tools, f)
	}
	resp.SUCCESS(c, tools)
}