diff --git a/CHANGELOG.md b/CHANGELOG.md index fa865cd2..e5dee4d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v4.0.4 * Bug修复:修复统一千问第二句不回复的问题 +* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表 * 功能新增:Dalle AI 绘画功能实现 ## v4.0.3 diff --git a/api/core/app_server.go b/api/core/app_server.go index 5c9d2ad6..b61b47ab 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -8,6 +8,7 @@ import ( "chatplus/utils/resp" "context" "fmt" + "github.com/chai2010/webp" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" @@ -16,7 +17,6 @@ import ( "image" "image/jpeg" "io" - "log" "net/http" "os" "runtime/debug" @@ -215,6 +215,8 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/client" || + c.Request.URL.Path == "/api/dall/imgWall" || + c.Request.URL.Path == "/api/dall/client" || c.Request.URL.Path == "/api/config/get" || c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/menu/list" || @@ -328,6 +330,10 @@ func staticResourceMiddleware() gin.HandlerFunc { // 解码图片 img, _, err := image.Decode(file) + // for .webp image + if err != nil { + img, err = webp.Decode(file) + } if err != nil { c.String(http.StatusInternalServerError, "Error decoding image") return @@ -344,7 +350,9 @@ func staticResourceMiddleware() gin.HandlerFunc { var buffer bytes.Buffer err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) if err != nil { - log.Fatal(err) + logger.Error(err) + c.String(http.StatusInternalServerError, err.Error()) + return } // 设置图片缓存有效期为一年 (365天) diff --git a/api/core/types/task.go b/api/core/types/task.go index cd4b516e..bb1f7689 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -59,3 +59,16 @@ type SdTaskParams struct { HdScaleAlg string `json:"hd_scale_alg"` // 放大算法 HdSteps int `json:"hd_steps"` // 高清修复迭代步数 } + +// DallTask DALL-E task +type DallTask struct { + JobId uint `json:"job_id"` + UserId uint `json:"user_id"` + Prompt string `json:"prompt"` + N int `json:"n"` + Quality string `json:"quality"` + Size string `json:"size"` + Style string `json:"style"` + + Power int `json:"power"` +} diff --git a/api/go.mod b/api/go.mod index fc131837..3e611305 100644 --- a/api/go.mod +++ b/api/go.mod @@ -32,9 +32,10 @@ require ( ) require ( + github.com/chai2010/webp v1.1.1 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect - golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect + golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect ) require ( diff --git a/api/go.sum b/api/go.sum index e5c987ce..06f881eb 100644 --- a/api/go.sum +++ b/api/go.sum @@ -12,6 +12,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk= +github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= @@ -241,6 +243,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index c991f670..1e4d2f78 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -104,8 +104,10 @@ func (h *ChatHandler) sendOpenAiMessage( res := h.DB.Where("name = ?", tool.Function.Name).First(&function) if res.Error == nil { toolCall = true + callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg}) + contents = append(contents, callMsg) } continue } diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go new file mode 100644 index 00000000..5a738a7e --- /dev/null +++ b/api/handler/dalle_handler.go @@ -0,0 +1,260 @@ +package handler + +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/service/dalle" + "chatplus/service/oss" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "chatplus/utils/resp" + "net/http" + "time" + + "github.com/gorilla/websocket" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type DallJobHandler struct { + BaseHandler + redis *redis.Client + service *dalle.Service + uploader *oss.UploaderManager +} + +func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { + return &DallJobHandler{ + service: service, + uploader: manager, + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *DallJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.service.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.service.Clients.Delete(uint(userId)) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 DallE 心跳消息:", message.Content) + continue + } + } + }() +} + +func (h *DallJobHandler) preCheck(c *gin.Context) bool { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.Power < h.App.SysConfig.SdPower { + resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *DallJobHandler) Image(c *gin.Context) { + if !h.preCheck(c) { + return + } + + var data types.DallTask + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.DallJob{ + UserId: uint(userId), + Prompt: data.Prompt, + Power: h.App.SysConfig.DallPower, + } + res := h.DB.Create(&job) + if res.Error != nil { + resp.ERROR(c, "error with save job: "+res.Error.Error()) + return + } + + h.service.PushTask(types.DallTask{ + JobId: job.Id, + UserId: uint(userId), + Prompt: data.Prompt, + Quality: data.Quality, + Size: data.Size, + Style: data.Style, + Power: job.Power, + }) + + client := h.service.Clients.Get(job.UserId) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) +} + +// ImgWall 照片墙 +func (h *DallJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 +func (h *DallJobHandler) JobList(c *gin.Context) { + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + publish := h.GetBool(c, "publish") + + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取任务列表 +func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { + + session := h.DB.Session(&gorm.Session{}) + if finish { + session = session.Where("progress = ?", 100).Order("id DESC") + } else { + session = session.Where("progress < ?", 100).Order("id ASC") + } + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if publish { + session = session.Where("publish", publish) + } + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + + var items []model.DallJob + res := session.Find(&items) + if res.Error != nil { + return res.Error, nil + } + + var jobs = make([]vo.DallJob, 0) + for _, item := range items { + // delete failed or timeout tasks + if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 { + h.DB.Delete(&item) + } + var job vo.DallJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + jobs = append(jobs, job) + } + + return nil, jobs +} + +// Remove remove task image +func (h *DallJobHandler) Remove(c *gin.Context) { + var data struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + ImgURL string `json:"img_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // remove job recode + res := h.DB.Delete(&model.DallJob{Id: data.Id}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // remove image + err := h.uploader.GetUploadHandler().Delete(data.ImgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + resp.SUCCESS(c) +} + +// Publish 发布/取消发布图片到画廊显示 +func (h *DallJobHandler) Publish(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true) + if res.Error != nil { + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index e9eb57df..9ef45cc5 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -3,27 +3,35 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service/dalle" "chatplus/service/oss" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "errors" "fmt" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/imroc/req/v3" "gorm.io/gorm" - "strings" - "time" ) type FunctionHandler struct { BaseHandler config types.ChatPlusApiConfig uploadManager *oss.UploaderManager + dallService *dalle.Service } -func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { +func NewFunctionHandler( + server *core.AppServer, + db *gorm.DB, + config *types.AppConfig, + manager *oss.UploaderManager, + dallService *dalle.Service) *FunctionHandler { return &FunctionHandler{ BaseHandler: BaseHandler{ App: server, @@ -31,6 +39,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo }, config: config.ApiConfig, uploadManager: manager, + dallService: dallService, } } @@ -151,30 +160,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) { resp.SUCCESS(c, strings.Join(builder, "\n\n")) } -type imgReq struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` -} - -type imgRes struct { - Created int64 `json:"created"` - Data []struct { - RevisedPrompt string `json:"revised_prompt"` - Url string `json:"url"` - } `json:"data"` -} - -type ErrRes struct { - Error struct { - Code interface{} `json:"code"` - Message string `json:"message"` - Param interface{} `json:"param"` - Type string `json:"type"` - } `json:"error"` -} - // Dall3 DallE3 AI 绘图 func (h *FunctionHandler) Dall3(c *gin.Context) { if err := h.checkAuth(c); err != nil { @@ -190,85 +175,40 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { logger.Debugf("绘画参数:%+v", params) var user model.User - tx := h.DB.Where("id = ?", params["user_id"]).First(&user) - if tx.Error != nil { + res := h.DB.Where("id = ?", params["user_id"]).First(&user) + if res.Error != nil { resp.ERROR(c, "当前用户不存在!") return } - if user.Power < h.App.SysConfig.DallPower { - resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") - return - } - + // create dall task prompt := utils.InterfaceToString(params["prompt"]) - // get image generation API KEY - var apiKey model.ApiKey - tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) - if tx.Error != nil { - resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) + job := model.DallJob{ + UserId: user.Id, + Prompt: prompt, + Power: h.App.SysConfig.DallPower, + } + res = h.DB.Create(&job) + + if res.Error != nil { + resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error()) return } - // translate prompt - const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" - pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"])) - if err == nil { - logger.Debugf("翻译绘画提示词,原文:%s,译文:%s", prompt, pt) - prompt = pt - } - var res imgRes - var errRes ErrRes - var request *req.Request - if len(apiKey.ProxyURL) > 5 { - request = req.C().SetProxyURL(apiKey.ProxyURL).R() - } else { - request = req.C().R() - } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) - r, err := request.SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(imgReq{ - Model: "dall-e-3", - Prompt: prompt, - N: 1, - Size: "1024x1024", - }). - SetErrorResult(&errRes). - SetSuccessResult(&res).Post(apiKey.ApiURL) - if r.IsErrorState() { - resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message) - return - } - // 更新 API KEY 的最后使用时间 - h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) - logger.Debugf("%+v", res) - // 存储图片 - imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false) + content, err := h.dallService.Image(types.DallTask{ + JobId: job.Id, + UserId: user.Id, + Prompt: job.Prompt, + N: 1, + Quality: "standard", + Size: "1024x1024", + Style: "vivid", + Power: job.Power, + }, true) if err != nil { - resp.ERROR(c, "下载图片失败: "+err.Error()) + resp.ERROR(c, "任务执行失败:"+err.Error()) return } - content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL) - // 更新用户算力 - tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower)) - // 记录算力变化日志 - if tx.Error == nil && tx.RowsAffected > 0 { - var u model.User - h.DB.Where("id", user.Id).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: user.Id, - Username: user.Username, - Type: types.PowerConsume, - Amount: h.App.SysConfig.DallPower, - Balance: u.Power, - Mark: types.PowerSub, - Model: "dall-e-3", - Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)), - CreatedAt: time.Now(), - }) - } - resp.SUCCESS(c, content) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index b9c3625e..4d9a03ec 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -65,7 +65,7 @@ func (h *SdJobHandler) Client(c *gin.Context) { logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } -func (h *SdJobHandler) checkLimits(c *gin.Context) bool { +func (h *SdJobHandler) preCheck(c *gin.Context) bool { user, err := h.GetLoginUser(c) if err != nil { resp.NotAuth(c) @@ -88,7 +88,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool { // Image 创建一个绘画任务 func (h *SdJobHandler) Image(c *gin.Context) { - if !h.checkLimits(c) { + if !h.preCheck(c) { return } @@ -298,7 +298,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) { client := h.pool.Clients.Get(data.UserId) if client != nil { - _ = client.Send([]byte("Task Updated")) + _ = client.Send([]byte(sd.Finished)) } resp.SUCCESS(c) diff --git a/api/main.go b/api/main.go index 586b7f70..56bf7caf 100644 --- a/api/main.go +++ b/api/main.go @@ -8,6 +8,7 @@ import ( "chatplus/handler/chatimpl" logger2 "chatplus/logger" "chatplus/service" + "chatplus/service/dalle" "chatplus/service/mj" "chatplus/service/oss" "chatplus/service/payment" @@ -153,6 +154,12 @@ func main() { }), fx.Provide(oss.NewUploaderManager), fx.Provide(mj.NewService), + fx.Provide(dalle.NewService), + fx.Invoke(func(service *dalle.Service) { + service.Run() + service.CheckTaskNotify() + service.DownloadImages() + }), // 邮件服务 fx.Provide(service.NewSmtpService), @@ -441,6 +448,16 @@ func main() { group := s.Engine.Group("/api/markMap/") group.Any("client", h.Client) }), + fx.Provide(handler.NewDallJobHandler), + fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { + group := s.Engine.Group("/api/dall") + group.Any("client", h.Client) + group.POST("image", h.Image) + group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) + group.POST("remove", h.Remove) + group.POST("publish", h.Publish) + }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { go func() { err := s.Run(db) diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go new file mode 100644 index 00000000..8e8c8c2f --- /dev/null +++ b/api/service/dalle/service.go @@ -0,0 +1,259 @@ +package dalle + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "chatplus/service" + "chatplus/service/oss" + "chatplus/service/sd" + "chatplus/store" + "chatplus/store/model" + "chatplus/utils" + "errors" + "fmt" + "github.com/go-redis/redis/v8" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +// DALL-E 绘画服务 + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +// PushTask push a new mj task in to task queue +func (s *Service) PushTask(task types.DallTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + go func() { + for { + var task types.DallTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + _, err = s.Image(task, false) + if err != nil { + logger.Errorf("error with image task: %v", err) + s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": err.Error(), + }) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) + } + } + }() +} + +type imgReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + Style string `json:"style"` +} + +type imgRes struct { + Created int64 `json:"created"` + Data []struct { + RevisedPrompt string `json:"revised_prompt"` + Url string `json:"url"` + } `json:"data"` +} + +type ErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func (s *Service) Image(task types.DallTask, sync bool) (string, error) { + logger.Debugf("绘画参数:%+v", task) + prompt := task.Prompt + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) + if err != nil { + return "", fmt.Errorf("error with translate prompt: %v", err) + } + prompt = content + logger.Debugf("重写后提示词:%s", prompt) + } + + var user model.User + s.db.Where("id", task.UserId).First(&user) + if user.Power < task.Power { + return "", errors.New("insufficient of power") + } + + // get image generation API KEY + var apiKey model.ApiKey + tx := s.db.Where("platform", types.OpenAI). + Where("type", "img"). + Where("enabled", true). + Order("last_used_at ASC").First(&apiKey) + if tx.Error != nil { + return "", fmt.Errorf("no available IMG api key: %v", tx.Error) + } + + var res imgRes + var errRes ErrRes + if len(apiKey.ProxyURL) > 5 { + s.httpClient.SetProxyURL(apiKey.ProxyURL).R() + } + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) + r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: "1024x1024", + Style: task.Style, + Quality: task.Quality, + }). + SetErrorResult(&errRes). + SetSuccessResult(&res).Post(apiKey.ApiURL) + if err != nil { + return "", fmt.Errorf("error with send request: %v", err) + } + + if r.IsErrorState() { + return "", fmt.Errorf("error with send request: %v", errRes.Error) + } + // update the api key last use time + s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + // update task progress + s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": 100, + "org_url": res.Data[0].Url, + "prompt": prompt, + }) + + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) + var content string + if sync { + imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n![](%s)\n", prompt, imgURL) + } + + // 更新用户算力 + tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + var u model.User + s.db.Where("id", user.Id).First(&u) + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: task.Power, + Balance: u.Power, + Mark: types.PowerSub, + Model: "dall-e-3", + Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), + CreatedAt: time.Now(), + }) + } + + return content, nil +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running DALL-E task notify checking ...") + for { + var message sd.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadImages() { + go func() { + var items []model.DallJob + for { + res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) + if res.Error != nil { + continue + } + + // download images + for _, v := range items { + if v.OrgURL == "" { + continue + } + + logger.Infof("try to download image: %s", v.OrgURL) + imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL) + if err != nil { + logger.Error("error with download image: %s, error: %v", imgURL, err) + continue + } + + } + + time.Sleep(time.Second * 5) + } + }() +} + +func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { + // sava image + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false) + if err != nil { + return "", err + } + + // update img_url + res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL) + if res.Error != nil { + return "", err + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed}) + return imgURL, nil +} diff --git a/api/service/types.go b/api/service/types.go index 9a8a0d00..15a538a2 100644 --- a/api/service/types.go +++ b/api/service/types.go @@ -1,4 +1,4 @@ package service -const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" +const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/api/store/model/dalle_job.go b/api/store/model/dalle_job.go index 56bbbcd7..de7a13a0 100644 --- a/api/store/model/dalle_job.go +++ b/api/store/model/dalle_job.go @@ -4,13 +4,13 @@ import "time" type DallJob struct { Id uint `gorm:"primarykey;column:id"` - UserId int - TaskId string + UserId uint Prompt string - ImgURL string - Publish bool - Power int - Progress int - ErrMsg string + ImgURL string + OrgURL string + Publish bool + Power int + Progress int + ErrMsg string CreatedAt time.Time } diff --git a/api/store/vo/dalle_job.go b/api/store/vo/dalle_job.go index d7ca4df1..28a6906d 100644 --- a/api/store/vo/dalle_job.go +++ b/api/store/vo/dalle_job.go @@ -1,14 +1,14 @@ package vo type DallJob struct { - Id uint `json:"id"` - UserId int `json:"user_id"` - TaskId string `json:"task_id"` + Id uint `json:"id"` + UserId int `json:"user_id"` Prompt string `json:"prompt"` - ImgURL string `json:"img_url"` - Publish bool `json:"publish"` - Power int `json:"power"` - Progress int `json:"progress"` - ErrMsg string `json:"err_msg"` - CreatedAt int64 `json:"created_at"` + ImgURL string `json:"img_url"` + OrgURL string `json:"org_url"` + Publish bool `json:"publish"` + Power int `json:"power"` + Progress int `json:"progress"` + ErrMsg string `json:"err_msg"` + CreatedAt int64 `json:"created_at"` } diff --git a/database/update-v4.0.4.sql b/database/update-v4.0.4.sql index 3be95643..d61cabb0 100644 --- a/database/update-v4.0.4.sql +++ b/database/update-v4.0.4.sql @@ -1 +1,6 @@ -CREATE TABLE `chatgpt_plus`.`chatgpt_dalle` ( `id` INT(11) NOT NULL AUTO_INCREMENT , `user_id` INT(11) NOT NULL COMMENT '用户ID' , `task_id` VARCHAR(20) NOT NULL COMMENT '任务ID' , `prompt` VARCHAR(2000) NOT NULL COMMENT '提示词' , `img_url` VARCHAR(255) NOT NULL COMMENT '图片地址' , `publish` TINYINT(1) NOT NULL COMMENT '是否发布' , `power` SMALLINT(3) NOT NULL COMMENT '消耗算力' , `progress` SMALLINT(3) NOT NULL COMMENT '任务进度' , `err_msg` VARCHAR(255) NOT NULL COMMENT '错误信息' , `created_at` DATETIME NOT NULL , PRIMARY KEY (`id`)) ENGINE = InnoDB COMMENT = 'DALLE 绘图任务表'; \ No newline at end of file +CREATE TABLE `chatgpt_plus`.`chatgpt_dall_jobs` ( `id` INT(11) NOT NULL AUTO_INCREMENT , `user_id` INT(11) NOT NULL COMMENT '用户ID' , `task_id` VARCHAR(20) NOT NULL COMMENT '任务ID' , `prompt` VARCHAR(2000) NOT NULL COMMENT '提示词' , `img_url` VARCHAR(255) NOT NULL COMMENT '图片地址' , `publish` TINYINT(1) NOT NULL COMMENT '是否发布' , `power` SMALLINT(3) NOT NULL COMMENT '消耗算力' , `progress` SMALLINT(3) NOT NULL COMMENT '任务进度' , `err_msg` VARCHAR(255) NOT NULL COMMENT '错误信息' , `created_at` DATETIME NOT NULL , PRIMARY KEY (`id`)) ENGINE = InnoDB COMMENT = 'DALLE 绘图任务表'; + +ALTER TABLE `chatgpt_dall_jobs` ADD `org_url` VARCHAR(400) NULL COMMENT '原图地址' AFTER `img_url`; +ALTER TABLE `chatgpt_dall_jobs` DROP `task_id`; + + diff --git a/web/src/assets/css/image-dall.styl b/web/src/assets/css/image-dall.styl new file mode 100644 index 00000000..caf514a0 --- /dev/null +++ b/web/src/assets/css/image-dall.styl @@ -0,0 +1,88 @@ +.page-dall { + background-color: #282c34; + + .inner { + display: flex; + + .sd-box { + margin 10px + background-color #262626 + border 1px solid #454545 + min-width 300px + max-width 300px + padding 10px + border-radius 10px + color #ffffff; + font-size 14px + + h2 { + font-weight: bold; + font-size 20px + text-align center + color #47fff1 + } + + // 隐藏滚动条 + + ::-webkit-scrollbar { + width: 0; + height: 0; + background-color: transparent; + } + + .sd-params { + margin-top 10px + overflow auto + + + .param-line { + padding 0 10px + + .grid-content + .form-item-inner { + display flex + + .info-icon { + margin-left 10px + position relative + top 8px + } + } + + } + + .param-line.pt { + padding-top 5px + padding-bottom 5px + } + + .text-info { + padding 10px + } + } + + .submit-btn { + padding 10px 15px 0 15px + text-align center + + .el-button { + width 100% + + span { + color #2D3A4B + } + } + } + } + + .el-form { + .el-form-item__label { + color #ffffff + } + } + + @import "task-list.styl" + } + +} + diff --git a/web/src/assets/css/image-sd.styl b/web/src/assets/css/image-sd.styl index 1408a597..ba860bc3 100644 --- a/web/src/assets/css/image-sd.styl +++ b/web/src/assets/css/image-sd.styl @@ -58,10 +58,6 @@ .text-info { padding 10px - - .el-tag { - margin-right 10px - } } } diff --git a/web/src/utils/http.js b/web/src/utils/http.js index dd0974cb..7f5b3caf 100644 --- a/web/src/utils/http.js +++ b/web/src/utils/http.js @@ -1,7 +1,7 @@ import axios from 'axios' import {getAdminToken, getSessionId, getUserToken} from "@/store/session"; -axios.defaults.timeout = 30000 +axios.defaults.timeout = 180000 axios.defaults.baseURL = process.env.VUE_APP_API_HOST axios.defaults.withCredentials = true; axios.defaults.headers.post['Content-Type'] = 'application/json' diff --git a/web/src/views/Dalle.vue b/web/src/views/Dalle.vue index c81a9cc9..8e219175 100644 --- a/web/src/views/Dalle.vue +++ b/web/src/views/Dalle.vue @@ -1,29 +1,19 @@