mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-18 06:06:01 +08:00
349 lines
8.3 KiB
Go
349 lines
8.3 KiB
Go
package handler
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"geekai/core"
|
||
"geekai/core/types"
|
||
"geekai/service"
|
||
"geekai/service/crawler"
|
||
"geekai/service/dalle"
|
||
"geekai/service/oss"
|
||
"geekai/store/model"
|
||
"geekai/store/vo"
|
||
"geekai/utils"
|
||
"geekai/utils/resp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"github.com/imroc/req/v3"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type FunctionHandler struct {
|
||
BaseHandler
|
||
config types.ApiConfig
|
||
uploadManager *oss.UploaderManager
|
||
dallService *dalle.Service
|
||
userService *service.UserService
|
||
}
|
||
|
||
func NewFunctionHandler(
|
||
server *core.AppServer,
|
||
db *gorm.DB,
|
||
config *types.AppConfig,
|
||
manager *oss.UploaderManager,
|
||
dallService *dalle.Service,
|
||
userService *service.UserService) *FunctionHandler {
|
||
return &FunctionHandler{
|
||
BaseHandler: BaseHandler{
|
||
App: server,
|
||
DB: db,
|
||
},
|
||
config: config.ApiConfig,
|
||
uploadManager: manager,
|
||
dallService: dallService,
|
||
userService: userService,
|
||
}
|
||
}
|
||
|
||
type resVo struct {
|
||
Code types.BizCode `json:"code"`
|
||
Message string `json:"message"`
|
||
Data struct {
|
||
Title string `json:"title"`
|
||
UpdatedAt string `json:"updated_at"`
|
||
Items []dataItem `json:"items"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
// dataItem is one entry of resVo.Data.Items.
type dataItem struct {
	// Title is the entry's headline text.
	Title string `json:"title"`
	// Url links to the entry; WeiBo renders it as a Markdown link target.
	Url string `json:"url"`
	// Remark is free-form extra text; WeiBo shows it as the "热度" (heat) value.
	Remark string `json:"remark"`
}
|
||
|
||
// check authorization
|
||
func (h *FunctionHandler) checkAuth(c *gin.Context) error {
|
||
tokenString := c.GetHeader(types.UserAuthHeader)
|
||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||
}
|
||
|
||
return []byte(h.App.Config.Session.SecretKey), nil
|
||
})
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("error with parse auth token: %v", err)
|
||
}
|
||
|
||
claims, ok := token.Claims.(jwt.MapClaims)
|
||
if !ok || !token.Valid {
|
||
return errors.New("token is invalid")
|
||
}
|
||
|
||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||
return errors.New("token is expired")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// WeiBo 微博热搜
|
||
func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
||
if err := h.checkAuth(c); err != nil {
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
|
||
if h.config.Token == "" {
|
||
resp.ERROR(c, "无效的 API Token")
|
||
return
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
|
||
var res resVo
|
||
r, err := req.C().R().
|
||
SetHeader("AppId", h.config.AppId).
|
||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||
SetSuccessResult(&res).Get(url)
|
||
if err != nil {
|
||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||
return
|
||
}
|
||
if r.IsErrorState() {
|
||
resp.ERROR(c, fmt.Sprintf("error http code status: %v", r.Status))
|
||
}
|
||
|
||
if res.Code != types.Success {
|
||
resp.ERROR(c, res.Message)
|
||
return
|
||
}
|
||
|
||
builder := make([]string, 0)
|
||
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
|
||
for i, v := range res.Data.Items {
|
||
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
|
||
}
|
||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||
}
|
||
|
||
// ZaoBao 今日早报
|
||
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||
if err := h.checkAuth(c); err != nil {
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
|
||
if h.config.Token == "" {
|
||
resp.ERROR(c, "无效的 API Token")
|
||
return
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
|
||
var res resVo
|
||
r, err := req.C().R().
|
||
SetHeader("AppId", h.config.AppId).
|
||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||
SetSuccessResult(&res).Get(url)
|
||
if err != nil {
|
||
resp.ERROR(c, fmt.Sprintf("%v", err))
|
||
return
|
||
}
|
||
if r.IsErrorState() {
|
||
resp.ERROR(c, fmt.Sprintf("%v", r.Err))
|
||
return
|
||
}
|
||
|
||
if res.Code != types.Success {
|
||
resp.ERROR(c, res.Message)
|
||
return
|
||
}
|
||
|
||
builder := make([]string, 0)
|
||
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
|
||
for _, v := range res.Data.Items {
|
||
builder = append(builder, v.Title)
|
||
}
|
||
builder = append(builder, res.Data.Title)
|
||
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||
}
|
||
|
||
// Dall3 DallE3 AI 绘图
|
||
func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||
if err := h.checkAuth(c); err != nil {
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
|
||
var params map[string]interface{}
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
logger.Debugf("绘画参数:%+v", params)
|
||
var user model.User
|
||
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "当前用户不存在!")
|
||
return
|
||
}
|
||
|
||
if user.Power < h.App.SysConfig.DallPower {
|
||
resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
|
||
return
|
||
}
|
||
|
||
// create dall task
|
||
prompt := utils.InterfaceToString(params["prompt"])
|
||
task := types.DallTask{
|
||
UserId: user.Id,
|
||
Prompt: prompt,
|
||
ModelId: 0,
|
||
ModelName: "dall-e-3",
|
||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||
N: 1,
|
||
Quality: "standard",
|
||
Size: "1024x1024",
|
||
Style: "vivid",
|
||
Power: h.App.SysConfig.DallPower,
|
||
}
|
||
job := model.DallJob{
|
||
UserId: user.Id,
|
||
Prompt: prompt,
|
||
Power: h.App.SysConfig.DallPower,
|
||
TaskInfo: utils.JsonEncode(task),
|
||
}
|
||
err := h.DB.Create(&job).Error
|
||
if err != nil {
|
||
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+err.Error())
|
||
return
|
||
}
|
||
|
||
task.Id = job.Id
|
||
content, err := h.dallService.Image(task, true)
|
||
if err != nil {
|
||
resp.ERROR(c, "任务执行失败:"+err.Error())
|
||
return
|
||
}
|
||
|
||
// 扣减算力
|
||
err = h.userService.DecreasePower(int(user.Id), job.Power, model.PowerLog{
|
||
Type: types.PowerConsume,
|
||
Model: task.ModelName,
|
||
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(job.Prompt, 10)),
|
||
})
|
||
if err != nil {
|
||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||
return
|
||
}
|
||
|
||
resp.SUCCESS(c, content)
|
||
}
|
||
|
||
// 实现一个联网搜索的函数工具,采用爬虫实现
|
||
func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||
if err := h.checkAuth(c); err != nil {
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
|
||
var params map[string]interface{}
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
// 从参数中获取搜索关键词
|
||
keyword, ok := params["keyword"].(string)
|
||
if !ok || keyword == "" {
|
||
resp.ERROR(c, "搜索关键词不能为空")
|
||
return
|
||
}
|
||
|
||
// 从参数中获取最大页数,默认为1页
|
||
maxPages := 1
|
||
if pages, ok := params["max_pages"].(float64); ok {
|
||
maxPages = int(pages)
|
||
}
|
||
|
||
// 获取用户ID
|
||
userID, ok := params["user_id"].(float64)
|
||
if !ok {
|
||
resp.ERROR(c, "用户ID不能为空")
|
||
return
|
||
}
|
||
|
||
// 查询用户信息
|
||
var user model.User
|
||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "用户不存在")
|
||
return
|
||
}
|
||
|
||
// 检查用户算力是否足够
|
||
searchPower := 1 // 每次搜索消耗1点算力
|
||
if user.Power < searchPower {
|
||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
||
return
|
||
}
|
||
|
||
// 执行网络搜索
|
||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
||
if err != nil {
|
||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 扣减用户算力
|
||
err = h.userService.DecreasePower(int(user.Id), searchPower, model.PowerLog{
|
||
Type: types.PowerConsume,
|
||
Model: "web_search",
|
||
Remark: fmt.Sprintf("网络搜索:%s", utils.CutWords(keyword, 10)),
|
||
})
|
||
if err != nil {
|
||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||
return
|
||
}
|
||
|
||
// 返回搜索结果
|
||
resp.SUCCESS(c, searchResults)
|
||
}
|
||
|
||
// List 获取所有的工具函数列表
|
||
func (h *FunctionHandler) List(c *gin.Context) {
|
||
var items []model.Function
|
||
err := h.DB.Where("enabled", true).Find(&items).Error
|
||
if err != nil {
|
||
resp.ERROR(c, err.Error())
|
||
return
|
||
}
|
||
|
||
tools := make([]vo.Function, 0)
|
||
for _, v := range items {
|
||
var f vo.Function
|
||
err = utils.CopyObject(v, &f)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
f.Action = ""
|
||
f.Token = ""
|
||
tools = append(tools, f)
|
||
}
|
||
|
||
resp.SUCCESS(c, tools)
|
||
}
|