From ce8fa792067c2e28fdf5f908bf1c65b77e3ea532 Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 28 Nov 2023 12:04:02 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=BA=E5=A4=A7=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E7=94=9F=E6=88=90=E7=BC=A9=E7=95=A5=E5=9B=BE=EF=BC=8C=E5=8A=A0?= =?UTF-8?q?=E5=BF=AB=E5=89=8D=E7=AB=AF=E5=9B=BE=E7=89=87=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E9=80=9F=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/app_server.go | 55 +++++++++++++++++++++++++++++++++++++++ api/service/mj/service.go | 14 +++++++--- api/test/test.go | 37 -------------------------- web/src/views/ImageMj.vue | 27 +++++++++---------- 4 files changed, 77 insertions(+), 56 deletions(-) diff --git a/api/core/app_server.go b/api/core/app_server.go index ef3032ae..cb2241ae 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -12,9 +12,14 @@ import ( "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" + "github.com/nfnt/resize" "gorm.io/gorm" + "image" + "image/jpeg" "io" + "log" "net/http" + "os" "runtime/debug" "strings" "time" @@ -58,6 +63,7 @@ func (s *AppServer) Init(debug bool, client *redis.Client) { logger.Info("Enabled debug mode") } s.Engine.Use(corsMiddleware()) + s.Engine.Use(staticResourceMiddleware()) s.Engine.Use(authorizeMiddleware(s, client)) s.Engine.Use(parameterHandlerMiddleware()) s.Engine.Use(errorHandler) @@ -274,3 +280,52 @@ func trimJSONStrings(data interface{}) { } } } + +// 静态资源中间件 +func staticResourceMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + + url := c.Request.URL.String() + // 拦截生成缩略图请求 + if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") { + r := strings.SplitAfter(url, "imageView2") + size := strings.Split(r[1], "/") + if len(size) != 8 { + c.String(http.StatusNotFound, "invalid thumb args") + return + } + with := utils.IntValue(size[3], 0) + height := utils.IntValue(size[5], 0) + quality := utils.IntValue(size[7], 75) + + // 打开图片文件 + filePath := strings.TrimLeft(c.Request.URL.Path, "/") + file, err := os.Open(filePath) + if err != nil { + c.String(http.StatusNotFound, "Image not found") + return + } + defer file.Close() + + // 解码图片 + img, _, err := image.Decode(file) + if err != nil { + c.String(http.StatusInternalServerError, "Error decoding image") + return + } + + // 生成缩略图 + resizedImg := resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3) + var buffer bytes.Buffer + err = jpeg.Encode(&buffer, resizedImg, &jpeg.Options{Quality: quality}) + if err != nil { + log.Fatal(err) + } + + // 直接输出图像数据流 + c.Data(http.StatusOK, "image/jpeg", buffer.Bytes()) + return + } + c.Next() + } +} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 9ab320a2..ffabd73c 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -72,11 +72,17 @@ func (s *Service) Run() { } if err != nil { logger.Error("绘画任务执行失败:", err) - if task.RetryCount <= 5 { - s.taskQueue.RPush(task) + // 推送任务到前端 + client := s.Clients.Get(task.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, vo.MidJourneyJob{ + Type: task.Type.String(), + UserId: task.UserId, + MessageId: task.MessageId, + Progress: -1, + Prompt: task.Prompt, + }) } - task.RetryCount += 1 - time.Sleep(time.Second * 3) continue } diff --git a/api/test/test.go b/api/test/test.go index dd4d1c31..79058077 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,42 +1,5 @@ package main -import ( - "chatplus/store/model" - "chatplus/utils" - "fmt" - "gorm.io/driver/mysql" - "gorm.io/gorm" - "log" - "os" - "path" -) - func main() { - MysqlDns := "root:12345678@tcp(localhost:3306)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local" - db, err := gorm.Open(mysql.Open(MysqlDns), &gorm.Config{}) - if err != nil { - log.Fatal(err) - } - _ = os.MkdirAll("static/upload/images", 0755) - var jobs []model.MidJourneyJob - db.Find(&jobs) - for _, job := range jobs { - basename := path.Base(job.ImgURL) - imageData, err := utils.DownloadImage(job.ImgURL, "") - if err != nil { - fmt.Println("图片下载失败:" + job.ImgURL) - continue - } - newImagePath := fmt.Sprintf("static/upload/images/%s", basename) - err = os.WriteFile(newImagePath, imageData, 0644) - if err != nil { - fmt.Println("Error writing image file:", err) - continue - } - fmt.Println("图片保存成功!", newImagePath) - // 更新数据库 - job.ImgURL = fmt.Sprintf("http://localhost:5678/%s", newImagePath) - db.Updates(&job) - } } diff --git a/web/src/views/ImageMj.vue b/web/src/views/ImageMj.vue index 9b055ea5..4b4513ca 100644 --- a/web/src/views/ImageMj.vue +++ b/web/src/views/ImageMj.vue @@ -354,12 +354,14 @@ import {onMounted, ref} from "vue" import {ChromeFilled, DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue"; import Compressor from "compressorjs"; import {httpGet, httpPost} from "@/utils/http"; -import {ElMessage} from "element-plus"; +import {ElMessage, ElNotification} from "element-plus"; import ItemList from "@/components/ItemList.vue"; import Clipboard from "clipboard"; import {checkSession} from "@/action/session"; import {useRouter} from "vue-router"; import {getSessionId, getUserToken} from "@/store/session"; +import {removeArrayItem} from "@/utils/libs"; +import axios from "axios"; const listBoxHeight = ref(window.innerHeight - 40) const mjBoxHeight = ref(window.innerHeight - 150) @@ -432,6 +434,14 @@ const connect = () => { if (isNew) { finishedJobs.value.unshift(data) } + } else if (data.progress === -1) { // 任务执行失败 + ElNotification({ + title: '任务执行失败', + message: "提示词:" + data['prompt'], + type: 'error', + }) + runningJobs.value = removeArrayItem(runningJobs.value, data, (v1, v2) => v1.id === v2.id) + } else { for (let i = 0; i < runningJobs.value.length; i++) { if (runningJobs.value[i].id === data.id) { @@ -463,7 +473,7 @@ onMounted(() => { ElMessage.error("获取任务失败:" + e.message) }) - // 获取运行中的任务 + // 获取已完成的任务 httpGet(`/api/mj/jobs?status=1&user_id=${user['id']}`).then(res => { finishedJobs.value = res.data }).catch(e => { @@ -516,19 +526,6 @@ const afterRead = (file) => { }, }); }; - -const getTaskType = (type) => { - switch (type) { - case "image": - return "绘画任务" - case "upscale": - return "放大任务" - case "variation": - return "变化任务" - } - return "未知任务" -} - // 创建绘图任务 const promptRef = ref(null) const generate = () => {