diff --git a/api/core/app_server.go b/api/core/app_server.go index 887aad25..ef3032ae 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -148,6 +148,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/role/list" || c.Request.URL.Path == "/api/mj/jobs" || + c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/jobs" || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || diff --git a/api/core/types/config.go b/api/core/types/config.go index 5023a413..8e8f1094 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -145,7 +145,6 @@ type SystemConfig struct { VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数 EnabledRegister bool `json:"enabled_register"` // 是否启用注册功能,关闭注册功能之后将无法注册 EnabledMsg bool `json:"enabled_msg"` // 是否启用短信验证码服务 - EnabledDraw bool `json:"enabled_draw"` // 是否启用 AI 绘画功能 RewardImg string `json:"reward_img"` // 众筹收款二维码地址 EnabledFunction bool `json:"enabled_function"` // 启用 API 函数功能 EnabledReward bool `json:"enabled_reward"` // 启用众筹功能 diff --git a/api/core/types/function.go b/api/core/types/function.go index 7b0a74db..f126d659 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -23,10 +23,10 @@ type Property struct { } const ( - FuncZaoBao = "zao_bao" // 每日早报 - FuncHeadLine = "headline" // 今日头条 - FuncWeibo = "weibo_hot" // 微博热搜 - FuncMidJourney = "mid_journey" // MJ 绘画 + FuncZaoBao = "zao_bao" // 每日早报 + FuncHeadLine = "headline" // 今日头条 + FuncWeibo = "weibo_hot" // 微博热搜 + FuncImage = "draw_image" // AI 绘画 ) var InnerFunctions = []Function{ @@ -76,14 +76,14 @@ var InnerFunctions = []Function{ }, { - Name: FuncMidJourney, - Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画", + Name: FuncImage, + Description: "AI 绘画工具,根据输入的绘图描述用 AI 工具进行绘画", Parameters: Parameters{ Type: "object", Properties: map[string]Property{ "prompt": { Type: "string", - Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。提示词中的参数作为提示的一部分,不要删除", + Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。", }, }, Required: []string{}, diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go index 50e24cdb..5c8d1a45 100644 --- a/api/handler/admin/api_key_handler.go +++ b/api/handler/admin/api_key_handler.go @@ -27,6 +27,7 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { var data struct { Id uint `json:"id"` Platform string `json:"platform"` + Type string `json:"type"` Value string `json:"value"` } if err := c.ShouldBindJSON(&data); err != nil { @@ -40,7 +41,8 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { } apiKey.Platform = data.Platform apiKey.Value = data.Value - res := h.db.Debug().Save(&apiKey) + apiKey.Type = data.Type + res := h.db.Save(&apiKey) if res.Error != nil { resp.ERROR(c, "更新数据库失败!") return diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go index e800544d..a4138616 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -127,12 +127,12 @@ func (h *ChatHandler) sendAzureMessage( logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params) // for creating image, check if the user's img_calls > 0 - if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 { + if functionName == types.FuncImage && userVo.ImgCalls <= 0 { utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**") utils.ReplyMessage(ws, ErrImg) } else { f := h.App.Functions[functionName] - if functionName == types.FuncMidJourney { + if functionName == types.FuncImage { params["user_id"] = userVo.Id params["role_id"] = role.Id params["chat_id"] = session.ChatId @@ -149,9 +149,8 @@ func (h *ChatHandler) sendAzureMessage( contents = append(contents, msg) } else { content := data - if functionName == types.FuncMidJourney { - content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - h.mjService.ChatClients.Put(session.SessionId, ws) + if functionName == types.FuncImage { + content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"]) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index a310f21c..1a25fea3 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -224,9 +224,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio if h.App.SysConfig.EnabledFunction { var functions = make([]types.Function, 0) for _, f := range types.InnerFunctions { - if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney { - continue - } functions = append(functions, f) } req.Functions = functions @@ -405,7 +402,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } if *apiKey == "" { var key model.ApiKey - res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) + res := h.db.Where("platform = ? AND type = ?", platform, "chat").Order("last_used_at ASC").First(&key) if res.Error != nil { return nil, errors.New("no available key, please import key") } diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index 69a2b85e..cda0836c 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -126,12 +126,12 @@ func (h *ChatHandler) sendOpenAiMessage( logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params) // for creating image, check if the user's img_calls > 0 - if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 { + if functionName == types.FuncImage && userVo.ImgCalls <= 0 { utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**") utils.ReplyMessage(ws, ErrImg) } else { f := h.App.Functions[functionName] - if functionName == types.FuncMidJourney { + if functionName == types.FuncImage { params["user_id"] = userVo.Id params["role_id"] = role.Id params["chat_id"] = session.ChatId @@ -148,9 +148,8 @@ func (h *ChatHandler) sendOpenAiMessage( contents = append(contents, msg) } else { content := data - if functionName == types.FuncMidJourney { - content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - h.mjService.ChatClients.Put(session.SessionId, ws) + if functionName == types.FuncImage { + content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景。%s", params["prompt"], data) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index 8bbc3084..2221d9e9 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -69,7 +69,7 @@ func (h *ChatHandler) sendXunFeiMessage( var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] if apiKey == "" { var key model.ApiKey - res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key) + res := h.db.Where("platform = ? AND type = ?", session.Model.Platform, "chat").Order("last_used_at ASC").First(&key) if res.Error != nil { utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") return nil diff --git a/api/service/fun/func_img.go b/api/service/fun/func_img.go new file mode 100644 index 00000000..004251d6 --- /dev/null +++ b/api/service/fun/func_img.go @@ -0,0 +1,92 @@ +package fun + +import ( + "chatplus/core/types" + "chatplus/service/oss" + "chatplus/store/model" + "chatplus/utils" + "fmt" + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +// AI 绘画函数 + +type FuncImage struct { + name string + apiURL string + db *gorm.DB + uploadManager *oss.UploaderManager + proxyURL string +} + +func NewImageFunc(db *gorm.DB, manager *oss.UploaderManager, config *types.AppConfig) FuncImage { + return FuncImage{ + db: db, + name: "DALL-E3 绘画", + uploadManager: manager, + proxyURL: config.ProxyURL, + apiURL: "https://api.openai.com/v1/images/generations", + } +} + +type imgReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` +} + +type imgRes struct { + Created int64 `json:"created"` + Data []struct { + RevisedPrompt string `json:"revised_prompt"` + Url string `json:"url"` + } `json:"data"` +} + +type ErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func (f FuncImage) Invoke(params map[string]interface{}) (string, error) { + logger.Infof("绘画参数:%+v", params) + prompt := utils.InterfaceToString(params["prompt"]) + // 获取绘图 API KEY + var apiKey model.ApiKey + f.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey) + var res imgRes + var errRes ErrRes + r, err := req.C().SetProxyURL(f.proxyURL).R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: "1024x1024", + }). + SetErrorResult(&errRes). + SetSuccessResult(&res).Post(f.apiURL) + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) + } + // 存储图片 + imgURL, err := f.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) + if err != nil { + return "", fmt.Errorf("下载图片失败: %s", err.Error()) + } + + logger.Info(imgURL) + return fmt.Sprintf("\n\n\n", imgURL), nil +} + +func (f FuncImage) Name() string { + return f.name +} + +var _ Function = &FuncImage{} diff --git a/api/service/fun/func_mj.go b/api/service/fun/func_mj.go deleted file mode 100644 index bbd83554..00000000 --- a/api/service/fun/func_mj.go +++ /dev/null @@ -1,49 +0,0 @@ -package fun - -import ( - "chatplus/core/types" - "chatplus/service/mj" - "chatplus/utils" - "errors" -) - -// AI 绘画函数 - -type FuncMidJourney struct { - name string - service *mj.Service - config types.MidJourneyConfig -} - -func NewMidJourneyFunc(mjService *mj.Service, config types.MidJourneyConfig) FuncMidJourney { - return FuncMidJourney{ - name: "MidJourney AI 绘画", - config: config, - service: mjService} -} - -func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { - if !f.config.Enabled { - return "", errors.New("MidJourney AI 绘画功能没有启用") - } - - logger.Infof("MJ 绘画参数:%+v", params) - prompt := utils.InterfaceToString(params["prompt"]) - f.service.PushTask(types.MjTask{ - SessionId: utils.InterfaceToString(params["session_id"]), - Src: types.TaskSrcChat, - Type: types.TaskImage, - Prompt: prompt, - UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), - RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), - Icon: utils.InterfaceToString(params["icon"]), - ChatId: utils.InterfaceToString(params["chat_id"]), - }) - return prompt, nil -} - -func (f FuncMidJourney) Name() string { - return f.name -} - -var _ Function = &FuncMidJourney{} diff --git a/api/service/fun/function.go b/api/service/fun/function.go index 8b33dcbb..ac75de60 100644 --- a/api/service/fun/function.go +++ b/api/service/fun/function.go @@ -3,7 +3,8 @@ package fun import ( "chatplus/core/types" logger2 "chatplus/logger" - "chatplus/service/mj" + "chatplus/service/oss" + "gorm.io/gorm" ) type Function interface { @@ -29,11 +30,11 @@ type dataItem struct { Remark string `json:"remark"` } -func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function { +func NewFunctions(config *types.AppConfig, db *gorm.DB, manager *oss.UploaderManager) map[string]Function { return map[string]Function{ - types.FuncZaoBao: NewZaoBao(config.ApiConfig), - types.FuncWeibo: NewWeiboHot(config.ApiConfig), - types.FuncHeadLine: NewHeadLines(config.ApiConfig), - types.FuncMidJourney: NewMidJourneyFunc(mjService, config.MjConfig), + types.FuncZaoBao: NewZaoBao(config.ApiConfig), + types.FuncWeibo: NewWeiboHot(config.ApiConfig), + types.FuncHeadLine: NewHeadLines(config.ApiConfig), + types.FuncImage: NewImageFunc(db, manager, config), } } diff --git a/api/test/test.go b/api/test/test.go index abb09f91..90ee566f 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,13 +1,23 @@ package main import ( + "chatplus/utils" "fmt" - "reflect" - "time" + "os" ) func main() { - r := time.Now() - f := reflect.ValueOf(r) - fmt.Println(f.Type().Kind()) + imgURL := "https://oaidalleapiprodscus.blob.core.windows.net/private/org-UJimNEKhVm07E58nxnjx5FeG/user-e5UAcPVbkm2nwD8urggRRM8q/img-zFXWyrJ9Z1HppI36dZMXNEaA.png?st=2023-11-26T09%3A57%3A49Z&se=2023-11-26T11%3A57%3A49Z&sp=r&sv=2021-08-06&sr=b&rscd=inline&rsct=image/png&skoid=6aaadede-4fb3-4698-a8f6-684d7786b067&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2023-11-26T08%3A14%3A59Z&ske=2023-11-27T08%3A14%3A59Z&sks=b&skv=2021-08-06&sig=VmlU9didavbl02XYim2XuMmLMFJsLtCY/ULnzCjeO1g%3D" + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + panic(err) + } + newImagePath := "newimage.png" + err = os.WriteFile(newImagePath, imageData, 0644) + if err != nil { + fmt.Println("Error writing image file:", err) + return + } + + fmt.Println("图片保存成功!") } diff --git a/web/src/views/Home.vue b/web/src/views/Home.vue index 031f5d73..d89af195 100644 --- a/web/src/views/Home.vue +++ b/web/src/views/Home.vue @@ -2,10 +2,7 @@