dalle image page is ready

This commit is contained in:
RockYang 2024-04-21 20:23:47 +08:00
parent ab8240613e
commit 60cf380f96
21 changed files with 899 additions and 603 deletions

View File

@ -3,6 +3,7 @@
## v4.0.4 ## v4.0.4
* Bug修复修复统一千问第二句不回复的问题 * Bug修复修复统一千问第二句不回复的问题
* 功能优化MJ 和 SD 任务正在执行时不更新已完成任务列表
* 功能新增Dalle AI 绘画功能实现 * 功能新增Dalle AI 绘画功能实现
## v4.0.3 ## v4.0.3

View File

@ -8,6 +8,7 @@ import (
"chatplus/utils/resp" "chatplus/utils/resp"
"context" "context"
"fmt" "fmt"
"github.com/chai2010/webp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@ -16,7 +17,6 @@ import (
"image" "image"
"image/jpeg" "image/jpeg"
"io" "io"
"log"
"net/http" "net/http"
"os" "os"
"runtime/debug" "runtime/debug"
@ -215,6 +215,8 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" || c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/config/get" || c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" || c.Request.URL.Path == "/api/menu/list" ||
@ -328,6 +330,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 解码图片 // 解码图片
img, _, err := image.Decode(file) img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil { if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image") c.String(http.StatusInternalServerError, "Error decoding image")
return return
@ -344,7 +350,9 @@ func staticResourceMiddleware() gin.HandlerFunc {
var buffer bytes.Buffer var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil { if err != nil {
log.Fatal(err) logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
} }
// 设置图片缓存有效期为一年 (365天) // 设置图片缓存有效期为一年 (365天)

View File

@ -59,3 +59,16 @@ type SdTaskParams struct {
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法 HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数 HdSteps int `json:"hd_steps"` // 高清修复迭代步数
} }
// DallTask DALL-E task
type DallTask struct {
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
}

View File

@ -32,9 +32,10 @@ require (
) )
require ( require (
github.com/chai2010/webp v1.1.1 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
) )
require ( require (

View File

@ -12,6 +12,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@ -241,6 +243,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=

View File

@ -104,8 +104,10 @@ func (h *ChatHandler) sendOpenAiMessage(
res := h.DB.Where("name = ?", tool.Function.Name).First(&function) res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
} }
continue continue
} }

View File

@ -0,0 +1,260 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/dalle"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.DallJob{
UserId: uint(userId),
Prompt: data.Prompt,
Power: h.App.SysConfig.DallPower,
}
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *DallJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
}
var jobs = make([]vo.DallJob, 0)
for _, item := range items {
// delete failed or timeout tasks
if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
h.DB.Delete(&item)
}
var job vo.DallJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
jobs = append(jobs, job)
}
return nil, jobs
}
// Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.DallJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@ -3,27 +3,35 @@ package handler
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/service/dalle"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"errors" "errors"
"fmt" "fmt"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
"time"
) )
type FunctionHandler struct { type FunctionHandler struct {
BaseHandler BaseHandler
config types.ChatPlusApiConfig config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager uploadManager *oss.UploaderManager
dallService *dalle.Service
} }
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler { func NewFunctionHandler(
server *core.AppServer,
db *gorm.DB,
config *types.AppConfig,
manager *oss.UploaderManager,
dallService *dalle.Service) *FunctionHandler {
return &FunctionHandler{ return &FunctionHandler{
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: server, App: server,
@ -31,6 +39,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
}, },
config: config.ApiConfig, config: config.ApiConfig,
uploadManager: manager, uploadManager: manager,
dallService: dallService,
} }
} }
@ -151,30 +160,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
resp.SUCCESS(c, strings.Join(builder, "\n\n")) resp.SUCCESS(c, strings.Join(builder, "\n\n"))
} }
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
// Dall3 DallE3 AI 绘图 // Dall3 DallE3 AI 绘图
func (h *FunctionHandler) Dall3(c *gin.Context) { func (h *FunctionHandler) Dall3(c *gin.Context) {
if err := h.checkAuth(c); err != nil { if err := h.checkAuth(c); err != nil {
@ -190,85 +175,40 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
logger.Debugf("绘画参数:%+v", params) logger.Debugf("绘画参数:%+v", params)
var user model.User var user model.User
tx := h.DB.Where("id = ?", params["user_id"]).First(&user) res := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil { if res.Error != nil {
resp.ERROR(c, "当前用户不存在!") resp.ERROR(c, "当前用户不存在!")
return return
} }
if user.Power < h.App.SysConfig.DallPower { // create dall task
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return
}
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY job := model.DallJob{
var apiKey model.ApiKey UserId: user.Id,
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) Prompt: prompt,
if tx.Error != nil { Power: h.App.SysConfig.DallPower,
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error()) }
res = h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
return return
} }
// translate prompt content, err := h.dallService.Image(types.DallTask{
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" JobId: job.Id,
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"])) UserId: user.Id,
if err == nil { Prompt: job.Prompt,
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt) N: 1,
prompt = pt Quality: "standard",
} Size: "1024x1024",
var res imgRes Style: "vivid",
var errRes ErrRes Power: job.Power,
var request *req.Request }, true)
if len(apiKey.ProxyURL) > 5 {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if r.IsErrorState() {
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
return
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
if err != nil { if err != nil {
resp.ERROR(c, "下载图片失败: "+err.Error()) resp.ERROR(c, "任务执行失败:"+err.Error())
return return
} }
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// 更新用户算力
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content) resp.SUCCESS(c, content)
} }

View File

@ -65,7 +65,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
} }
func (h *SdJobHandler) checkLimits(c *gin.Context) bool { func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c) user, err := h.GetLoginUser(c)
if err != nil { if err != nil {
resp.NotAuth(c) resp.NotAuth(c)
@ -88,7 +88,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务 // Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) { func (h *SdJobHandler) Image(c *gin.Context) {
if !h.checkLimits(c) { if !h.preCheck(c) {
return return
} }
@ -298,7 +298,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
client := h.pool.Clients.Get(data.UserId) client := h.pool.Clients.Get(data.UserId)
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte(sd.Finished))
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@ -8,6 +8,7 @@ import (
"chatplus/handler/chatimpl" "chatplus/handler/chatimpl"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/service" "chatplus/service"
"chatplus/service/dalle"
"chatplus/service/mj" "chatplus/service/mj"
"chatplus/service/oss" "chatplus/service/oss"
"chatplus/service/payment" "chatplus/service/payment"
@ -153,6 +154,12 @@ func main() {
}), }),
fx.Provide(oss.NewUploaderManager), fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService), fx.Provide(mj.NewService),
fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
}),
// 邮件服务 // 邮件服务
fx.Provide(service.NewSmtpService), fx.Provide(service.NewSmtpService),
@ -441,6 +448,16 @@ func main() {
group := s.Engine.Group("/api/markMap/") group := s.Engine.Group("/api/markMap/")
group.Any("client", h.Client) group.Any("client", h.Client)
}), }),
fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() { go func() {
err := s.Run(db) err := s.Run(db)

View File

@ -0,0 +1,259 @@
package dalle
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// DALL-E 绘画服务
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
go func() {
for {
var task types.DallTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
}
}
}()
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
Style string `json:"style"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
s.db.Where("id", task.UserId).First(&user)
if user.Power < task.Power {
return "", errors.New("insufficient of power")
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI).
Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
}
var res imgRes
var errRes ErrRes
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message sd.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.DallJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err)
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
if err != nil {
return "", err
}
// update img_url
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
return imgURL, nil
}

View File

@ -1,4 +1,4 @@
package service package service
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]" const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@ -4,13 +4,13 @@ import "time"
type DallJob struct { type DallJob struct {
Id uint `gorm:"primarykey;column:id"` Id uint `gorm:"primarykey;column:id"`
UserId int UserId uint
TaskId string
Prompt string Prompt string
ImgURL string ImgURL string
Publish bool OrgURL string
Power int Publish bool
Progress int Power int
ErrMsg string Progress int
ErrMsg string
CreatedAt time.Time CreatedAt time.Time
} }

View File

@ -1,14 +1,14 @@
package vo package vo
type DallJob struct { type DallJob struct {
Id uint `json:"id"` Id uint `json:"id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
ImgURL string `json:"img_url"` ImgURL string `json:"img_url"`
Publish bool `json:"publish"` OrgURL string `json:"org_url"`
Power int `json:"power"` Publish bool `json:"publish"`
Progress int `json:"progress"` Power int `json:"power"`
ErrMsg string `json:"err_msg"` Progress int `json:"progress"`
CreatedAt int64 `json:"created_at"` ErrMsg string `json:"err_msg"`
CreatedAt int64 `json:"created_at"`
} }

View File

@ -1 +1,6 @@
CREATE TABLE `chatgpt_plus`.`chatgpt_dalle` ( `id` INT(11) NOT NULL AUTO_INCREMENT , `user_id` INT(11) NOT NULL COMMENT '用户ID' , `task_id` VARCHAR(20) NOT NULL COMMENT '任务ID' , `prompt` VARCHAR(2000) NOT NULL COMMENT '提示词' , `img_url` VARCHAR(255) NOT NULL COMMENT '图片地址' , `publish` TINYINT(1) NOT NULL COMMENT '是否发布' , `power` SMALLINT(3) NOT NULL COMMENT '消耗算力' , `progress` SMALLINT(3) NOT NULL COMMENT '任务进度' , `err_msg` VARCHAR(255) NOT NULL COMMENT '错误信息' , `created_at` DATETIME NOT NULL , PRIMARY KEY (`id`)) ENGINE = InnoDB COMMENT = 'DALLE 绘图任务表'; CREATE TABLE `chatgpt_plus`.`chatgpt_dall_jobs` ( `id` INT(11) NOT NULL AUTO_INCREMENT , `user_id` INT(11) NOT NULL COMMENT '用户ID' , `task_id` VARCHAR(20) NOT NULL COMMENT '任务ID' , `prompt` VARCHAR(2000) NOT NULL COMMENT '提示词' , `img_url` VARCHAR(255) NOT NULL COMMENT '图片地址' , `publish` TINYINT(1) NOT NULL COMMENT '是否发布' , `power` SMALLINT(3) NOT NULL COMMENT '消耗算力' , `progress` SMALLINT(3) NOT NULL COMMENT '任务进度' , `err_msg` VARCHAR(255) NOT NULL COMMENT '错误信息' , `created_at` DATETIME NOT NULL , PRIMARY KEY (`id`)) ENGINE = InnoDB COMMENT = 'DALLE 绘图任务表';
ALTER TABLE `chatgpt_dall_jobs` ADD `org_url` VARCHAR(400) NULL COMMENT '原图地址' AFTER `img_url`;
ALTER TABLE `chatgpt_dall_jobs` DROP `task_id`;

View File

@ -0,0 +1,88 @@
.page-dall {
background-color: #282c34;
.inner {
display: flex;
.sd-box {
margin 10px
background-color #262626
border 1px solid #454545
min-width 300px
max-width 300px
padding 10px
border-radius 10px
color #ffffff;
font-size 14px
h2 {
font-weight: bold;
font-size 20px
text-align center
color #47fff1
}
//
::-webkit-scrollbar {
width: 0;
height: 0;
background-color: transparent;
}
.sd-params {
margin-top 10px
overflow auto
.param-line {
padding 0 10px
.grid-content
.form-item-inner {
display flex
.info-icon {
margin-left 10px
position relative
top 8px
}
}
}
.param-line.pt {
padding-top 5px
padding-bottom 5px
}
.text-info {
padding 10px
}
}
.submit-btn {
padding 10px 15px 0 15px
text-align center
.el-button {
width 100%
span {
color #2D3A4B
}
}
}
}
.el-form {
.el-form-item__label {
color #ffffff
}
}
@import "task-list.styl"
}
}

View File

@ -58,10 +58,6 @@
.text-info { .text-info {
padding 10px padding 10px
.el-tag {
margin-right 10px
}
} }
} }

View File

@ -1,7 +1,7 @@
import axios from 'axios' import axios from 'axios'
import {getAdminToken, getSessionId, getUserToken} from "@/store/session"; import {getAdminToken, getSessionId, getUserToken} from "@/store/session";
axios.defaults.timeout = 30000 axios.defaults.timeout = 180000
axios.defaults.baseURL = process.env.VUE_APP_API_HOST axios.defaults.baseURL = process.env.VUE_APP_API_HOST
axios.defaults.withCredentials = true; axios.defaults.withCredentials = true;
axios.defaults.headers.post['Content-Type'] = 'application/json' axios.defaults.headers.post['Content-Type'] = 'application/json'

View File

@ -1,29 +1,19 @@
<template> <template>
<div> <div>
<div class="page-sd"> <div class="page-dall">
<div class="inner custom-scroll"> <div class="inner custom-scroll">
<div class="sd-box"> <div class="sd-box">
<h2>Stable Diffusion 创作中心</h2> <h2>DALL-E 创作中心</h2>
<div class="sd-params" :style="{ height: mjBoxHeight + 'px' }"> <div class="sd-params" :style="{ height: paramBoxHeight + 'px' }">
<el-form :model="params" label-width="80px" label-position="left"> <el-form :model="params" label-width="80px" label-position="left">
<div class="param-line" style="padding-top: 10px"> <div class="param-line" style="padding-top: 10px">
<el-form-item label="采样方法"> <el-form-item label="图片质量">
<template #default> <template #default>
<div class="form-item-inner"> <div class="form-item-inner">
<el-select v-model="params.sampler" style="width:176px"> <el-select v-model="params.quality" style="width:176px">
<el-option v-for="item in samplers" :label="item" :value="item" :key="item"/> <el-option v-for="v in qualities" :label="v.name" :value="v.value" :key="v.value"/>
</el-select> </el-select>
<el-tooltip
effect="light"
content="出图效果比较好的一般是 Euler 和 DPM 系列算法"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div> </div>
</template> </template>
</el-form-item> </el-form-item>
@ -33,27 +23,24 @@
<el-form-item label="图片尺寸"> <el-form-item label="图片尺寸">
<template #default> <template #default>
<div class="form-item-inner"> <div class="form-item-inner">
<el-row :gutter="20"> <el-select v-model="params.size" style="width:176px">
<el-col :span="12"> <el-option v-for="v in sizes" :label="v" :value="v" :key="v"/>
<el-input v-model.number="params.width" placeholder="图片宽度"/> </el-select>
</el-col>
<el-col :span="12">
<el-input v-model.number="params.height" placeholder="图片高度"/>
</el-col>
</el-row>
</div> </div>
</template> </template>
</el-form-item> </el-form-item>
</div> </div>
<div class="param-line"> <div class="param-line">
<el-form-item label="迭代步数"> <el-form-item label="图片样式">
<template #default> <template #default>
<div class="form-item-inner"> <div class="form-item-inner">
<el-input v-model.number="params.steps"/> <el-select v-model="params.style" style="width:176px">
<el-option v-for="v in styles" :label="v.name" :value="v.value" :key="v.value"/>
</el-select>
<el-tooltip <el-tooltip
effect="light" effect="light"
content="值越大则代表细节越多,同时也意味着出图速度越慢" content="生动使模型倾向于生成超真实和戏剧性的图像"
raw-content raw-content
placement="right" placement="right"
> >
@ -66,162 +53,6 @@
</el-form-item> </el-form-item>
</div> </div>
<div class="param-line">
<el-form-item label="引导系数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.cfg_scale"/>
<el-tooltip
effect="light"
content="提示词引导系数,图像在多大程度上服从提示词<br/> 较低值会产生更有创意的结果"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="随机因子">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.seed"/>
<el-tooltip
effect="light"
content="随机数种子,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
<el-tooltip
effect="light"
content="使用随机数"
raw-content
placement="right"
>
<el-icon @click="params.seed = -1" class="info-icon">
<Orange/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="高清修复">
<template #default>
<div class="form-item-inner">
<el-switch v-model="params.hd_fix" style="--el-switch-on-color: #47fff1;" size="large"/>
<el-tooltip
effect="light"
content="先以较小的分辨率生成图像,接着方法图像<br />然后在不更改构图的情况下再修改细节"
raw-content
placement="right"
>
<el-icon style="margin-left: 10px; top: 12px">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div v-show="params.hd_fix">
<div class="param-line">
<el-form-item label="重绘幅度">
<template #default>
<div class="form-item-inner">
<el-slider v-model.number="params.hd_redraw_rate" :max="1" :step="0.1"
style="width: 180px;--el-slider-main-bg-color:#47fff1"/>
<el-tooltip
effect="light"
content="决定算法对图像内容的影响程度<br />较大的值将得到越有创意的图像"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="放大算法">
<template #default>
<div class="form-item-inner">
<el-select v-model="params.hd_scale_alg" style="width:176px">
<el-option v-for="item in scaleAlg" :label="item" :value="item" :key="item"/>
</el-select>
<el-tooltip
effect="light"
content="高清修复放大算法主流算法有Latent和ESRGAN_4x"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="放大倍数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.hd_scale"/>
<el-tooltip
effect="light"
content="随机数种子,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="迭代步数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.hd_steps"/>
<el-tooltip
effect="light"
content="重绘迭代步数如果设置为0则设置跟原图相同的迭代步数"
raw-content
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
</div>
<div class="param-line"> <div class="param-line">
<el-input <el-input
v-model="params.prompt" v-model="params.prompt"
@ -232,36 +63,23 @@
/> />
</div> </div>
<div class="param-line pt">
<span>反向提示词</span>
<el-tooltip
effect="light"
content="不希望出现的元素,下面给了默认的起手式"
placement="right"
>
<el-icon class="info-icon">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
<div class="param-line">
<el-input
v-model="params.neg_prompt"
:autosize="{ minRows: 4, maxRows: 6 }"
type="textarea"
placeholder="反向提示词"
/>
</div>
<div class="text-info"> <div class="text-info">
<el-tag>每次绘图消耗{{ sdPower }}算力</el-tag> <el-row :gutter="10">
<el-tag type="success">当前可用算力{{ power }}</el-tag> <el-col :span="12">
<el-tag>每次绘图消耗{{ dallPower }}算力</el-tag>
</el-col>
<el-col :span="12">
<el-tag type="success">当前可用{{ power }}算力</el-tag>
</el-col>
</el-row>
</div> </div>
</el-form> </el-form>
</div> </div>
<div class="submit-btn"> <div class="submit-btn">
<el-button color="#47fff1" :dark="false" round @click="generate">立即生成</el-button> <el-button color="#47fff1" :dark="false" round @click="generate">
立即生成
</el-button>
</div> </div>
</div> </div>
<div class="task-list-box" @scrollend="handleScrollEnd"> <div class="task-list-box" @scrollend="handleScrollEnd">
@ -270,32 +88,9 @@
<h2>任务列表</h2> <h2>任务列表</h2>
<div class="running-job-list"> <div class="running-job-list">
<ItemList :items="runningJobs" v-if="runningJobs.length > 0" :width="240"> <ItemList :items="runningJobs" v-if="runningJobs.length > 0" :width="240">
<template #default="scope"> <template #default>
<div class="job-item"> <div class="job-item">
<div v-if="scope.item.progress > 0" class="job-item-inner"> <el-image fit="cover">
<el-image :src="scope.item['img_url']"
fit="cover"
loading="lazy">
<template #placeholder>
<div class="image-slot">
正在加载图片
</div>
</template>
<template #error>
<div class="image-slot">
<el-icon v-if="scope.item['img_url'] !== ''">
<Picture/>
</el-icon>
</div>
</template>
</el-image>
<div class="progress">
<el-progress type="circle" :percentage="scope.item.progress" :width="100" color="#47fff1"/>
</div>
</div>
<el-image fit="cover" v-else>
<template #error> <template #error>
<div class="image-slot"> <div class="image-slot">
<i class="iconfont icon-quick-start"></i> <i class="iconfont icon-quick-start"></i>
@ -308,16 +103,38 @@
</ItemList> </ItemList>
<el-empty :image-size="100" v-else/> <el-empty :image-size="100" v-else/>
</div> </div>
<h2>创作记录</h2> <h2>创作记录</h2>
<div class="finish-job-list" v-loading="loading" element-loading-background="rgba(0, 0, 0, 0.5)"> <div class="finish-job-list" v-loading="loading" element-loading-background="rgba(0, 0, 0, 0.5)">
<div v-if="finishedJobs.length > 0"> <div v-if="finishedJobs.length > 0">
<ItemList :items="finishedJobs" :width="240" :gap="16"> <ItemList :items="finishedJobs" :width="240" :gap="16">
<template #default="scope"> <template #default="scope">
<div class="job-item animate" @click="showTask(scope.item)"> <div class="job-item">
<el-image <el-image v-if="scope.item['img_url']"
:src="scope.item['img_url']+'?imageView2/1/w/240/h/240/q/75'" :src="scope.item['img_url']+'?imageView2/1/w/240/h/240/q/75'"
fit="cover" fit="cover"
loading="lazy"> :preview-src-list="[scope.item['img_url']]"
loading="lazy">
<template #placeholder>
<div class="image-slot">
正在加载图片
</div>
</template>
<template #error>
<div class="image-slot">
<el-icon>
<Picture/>
</el-icon>
</div>
</template>
</el-image>
<el-image v-else
:src="scope.item['org_url']"
fit="cover"
:preview-src-list="[scope.item['org_url']]"
loading="lazy">
<template #placeholder> <template #placeholder>
<div class="image-slot"> <div class="image-slot">
正在加载图片 正在加载图片
@ -334,15 +151,27 @@
</el-image> </el-image>
<div class="remove"> <div class="remove">
<el-button type="danger" :icon="Delete" @click="removeImage($event,scope.item)" circle/> <el-tooltip content="删除" placement="top" effect="light">
<el-button type="warning" v-if="scope.item.publish" <el-button type="danger" :icon="Delete" @click="removeImage($event,scope.item)" circle/>
@click="publishImage($event,scope.item, false)" </el-tooltip>
circle> <el-tooltip content="分享" placement="top" effect="light" v-if="scope.item.publish">
<i class="iconfont icon-cancel-share"></i> <el-button type="warning"
</el-button> @click="publishImage($event,scope.item, false)"
<el-button type="success" v-else @click="publishImage($event,scope.item, true)" circle> circle>
<i class="iconfont icon-share-bold"></i> <i class="iconfont icon-cancel-share"></i>
</el-button> </el-button>
</el-tooltip>
<el-tooltip content="取消分享" placement="top" effect="light" v-else>
<el-button type="success" @click="publishImage($event,scope.item, true)" circle>
<i class="iconfont icon-share-bold"></i>
</el-button>
</el-tooltip>
<el-tooltip content="复制提示词" placement="top" effect="light">
<el-button type="info" circle class="copy-prompt" :data-clipboard-text="scope.item.prompt">
<i class="iconfont icon-file"></i>
</el-button>
</el-tooltip>
</div> </div>
</div> </div>
</template> </template>
@ -361,118 +190,6 @@
</div><!-- end task list box --> </div><!-- end task list box -->
</div> </div>
<!-- 任务详情弹框 -->
<el-dialog v-model="showTaskDialog" title="绘画任务详情" :fullscreen="true">
<el-row :gutter="20">
<el-col :span="16">
<div class="img-container" :style="{maxHeight: fullImgHeight+'px'}">
<el-image :src="item['img_url']" fit="contain"/>
</div>
</el-col>
<el-col :span="8">
<div class="task-info">
<div class="info-line">
<el-divider>
正向提示词
</el-divider>
<div class="prompt">
<span>{{ item.prompt }}</span>
<el-icon class="copy-prompt-sd" :data-clipboard-text="item.prompt">
<DocumentCopy/>
</el-icon>
</div>
</div>
<div class="info-line">
<el-divider>
反向提示词
</el-divider>
<div class="prompt">
<span>{{ item.params.neg_prompt }}</span>
<el-icon class="copy-prompt-sd" :data-clipboard-text="item.params.neg_prompt">
<DocumentCopy/>
</el-icon>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>采样方法</label>
<div class="item-value">{{ item.params.sampler }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>图片尺寸</label>
<div class="item-value">{{ item.params.width }} x {{ item.params.height }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>迭代步数</label>
<div class="item-value">{{ item.params.steps }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>引导系数</label>
<div class="item-value">{{ item.params.cfg_scale }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>随机因子</label>
<div class="item-value">{{ item.params.seed }}</div>
</div>
</div>
<div v-if="item.params.hd_fix">
<el-divider>
高清修复
</el-divider>
<div class="info-line">
<div class="wrapper">
<label>重绘幅度</label>
<div class="item-value">{{ item.params.hd_redraw_rate }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>放大算法</label>
<div class="item-value">{{ item.params.hd_scale_alg }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>放大倍数</label>
<div class="item-value">{{ item.params.hd_scale }}</div>
</div>
</div>
<div class="info-line">
<div class="wrapper">
<label>迭代步数</label>
<div class="item-value">{{ item.params.hd_steps }}</div>
</div>
</div>
</div>
<div class="copy-params">
<el-button type="primary" round @click="copyParams(item)">画一张同款的</el-button>
</div>
</div>
</el-col>
</el-row>
</el-dialog>
</div> </div>
<login-dialog :show="showLoginDialog" @hide="showLoginDialog = false" @success="initData"/> <login-dialog :show="showLoginDialog" @hide="showLoginDialog = false" @success="initData"/>
@ -481,59 +198,93 @@
<script setup> <script setup>
import {onMounted, onUnmounted, ref} from "vue" import {onMounted, onUnmounted, ref} from "vue"
import {Delete, DocumentCopy, InfoFilled, Orange, Picture} from "@element-plus/icons-vue"; import {Delete, InfoFilled, Picture} from "@element-plus/icons-vue";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {ElMessage, ElMessageBox, ElNotification} from "element-plus"; import {ElMessage, ElMessageBox, ElNotification} from "element-plus";
import ItemList from "@/components/ItemList.vue"; import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {checkSession} from "@/action/session"; import {checkSession} from "@/action/session";
import {useRouter} from "vue-router";
import {getSessionId} from "@/store/session";
import LoginDialog from "@/components/LoginDialog.vue"; import LoginDialog from "@/components/LoginDialog.vue";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const paramBoxHeight = ref(window.innerHeight - 150)
const fullImgHeight = ref(window.innerHeight - 60)
const showTaskDialog = ref(false)
const item = ref({})
const showLoginDialog = ref(false) const showLoginDialog = ref(false)
const isLogin = ref(false) const isLogin = ref(false)
window.onresize = () => { window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40 listBoxHeight.value = window.innerHeight - 40
mjBoxHeight.value = window.innerHeight - 150 paramBoxHeight.value = window.innerHeight - 150
} }
const samplers = ["Euler a", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras"] const qualities = [
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"] {name: "标准", value: "standard"},
{name: "高清", value: "hd"},
]
const sizes = ["1024x1024", "1792x1024", "1024x1792"]
const styles = [
{name: "生动", value: "vivid"},
{name: "自然", value: "natural"}
]
const params = ref({ const params = ref({
width: 1024, quality: "standard",
height: 1024, size: "1024x1024",
sampler: samplers[0], style: "vivid",
seed: -1, prompt: ""
steps: 20,
cfg_scale: 7,
hd_fix: false,
hd_redraw_rate: 0.7,
hd_scale: 2,
hd_scale_alg: scaleAlg[0],
hd_steps: 0,
prompt: "",
neg_prompt: "nsfw, paintings,low quality,easynegative,ng_deepnegative ,lowres,bad anatomy,bad hands,bad feet",
}) })
const runningJobs = ref([])
const finishedJobs = ref([]) const finishedJobs = ref([])
const router = useRouter() const runningJobs = ref([])
//
const _params = router.currentRoute.value.params["copyParams"]
if (_params) {
params.value = JSON.parse(_params)
}
const power = ref(0) const power = ref(0)
const sdPower = ref(0) // SD const dallPower = ref(0) // SD
const clipboard = ref(null)
const userId = ref(0)
onMounted(() => {
initData()
clipboard.value = new Clipboard('.copy-prompt');
clipboard.value.on('success', () => {
ElMessage.success("复制成功!");
})
clipboard.value.on('error', () => {
ElMessage.error('复制失败!');
})
httpGet("/api/config/get?key=system").then(res => {
dallPower.value = res.data["dall_power"]
}).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message)
})
})
onUnmounted(() => {
clipboard.value.destroy()
if (socket.value !== null) {
socket.value.close()
socket.value = null
}
})
const initData = () => {
checkSession().then(user => {
power.value = user['power']
userId.value = user.id
isLogin.value = true
fetchRunningJobs()
fetchFinishJobs(1)
connect()
}).catch(() => {
loading.value = false
});
}
const handleScrollEnd = () => {
if (isOver.value === true) {
return
}
page.value += 1
fetchFinishJobs(page.value)
}
const socket = ref(null) const socket = ref(null)
const userId = ref(0)
const heartbeatHandle = ref(null) const heartbeatHandle = ref(null)
const connect = () => { const connect = () => {
let host = process.env.VUE_APP_WS_HOST let host = process.env.VUE_APP_WS_HOST
@ -558,7 +309,7 @@ const connect = () => {
}); });
} }
const _socket = new WebSocket(host + `/api/sd/client?user_id=${userId.value}`); const _socket = new WebSocket(host + `/api/dall/client?user_id=${userId.value}`);
_socket.addEventListener('open', () => { _socket.addEventListener('open', () => {
socket.value = _socket; socket.value = _socket;
@ -568,10 +319,17 @@ const connect = () => {
_socket.addEventListener('message', event => { _socket.addEventListener('message', event => {
if (event.data instanceof Blob) { if (event.data instanceof Blob) {
fetchRunningJobs() const reader = new FileReader();
isOver.value = false reader.readAsText(event.data, "UTF-8")
page.value = 1 reader.onload = () => {
fetchFinishJobs(page.value) const message = String(reader.result)
if (message === "FINISH") {
page.value = 1
fetchFinishJobs(page.value)
isOver.value = false
}
fetchRunningJobs()
}
} }
}); });
@ -582,47 +340,9 @@ const connect = () => {
}) })
} }
const clipboard = ref(null)
onMounted(() => {
initData()
clipboard.value = new Clipboard('.copy-prompt-sd');
clipboard.value.on('success', () => {
ElMessage.success("复制成功!");
})
clipboard.value.on('error', () => {
ElMessage.error('复制失败!');
})
httpGet("/api/config/get?key=system").then(res => {
sdPower.value = res.data["sd_power"]
}).catch(e => {
ElMessage.error("获取系统配置失败:" + e.message)
})
})
onUnmounted(() => {
clipboard.value.destroy()
socket.value = null
})
const initData = () => {
checkSession().then(user => {
power.value = user['power']
userId.value = user.id
isLogin.value = true
fetchRunningJobs()
fetchFinishJobs()
connect()
}).catch(() => {
loading.value = false
});
}
const fetchRunningJobs = () => { const fetchRunningJobs = () => {
// //
httpGet(`/api/sd/jobs?status=0`).then(res => { httpGet(`/api/dall/jobs?status=0`).then(res => {
const jobs = res.data const jobs = res.data
const _jobs = [] const _jobs = []
for (let i = 0; i < jobs.length; i++) { for (let i = 0; i < jobs.length; i++) {
@ -633,7 +353,7 @@ const fetchRunningJobs = () => {
message: `任务ID${jobs[i]['task_id']}<br />原因:${jobs[i]['err_msg']}`, message: `任务ID${jobs[i]['task_id']}<br />原因:${jobs[i]['err_msg']}`,
type: 'error', type: 'error',
}) })
power.value += sdPower.value power.value += dallPower.value
continue continue
} }
_jobs.push(jobs[i]) _jobs.push(jobs[i])
@ -644,14 +364,6 @@ const fetchRunningJobs = () => {
}) })
} }
const handleScrollEnd = () => {
if (isOver.value === true) {
return
}
page.value += 1
fetchFinishJobs(page.value)
}
const page = ref(1) const page = ref(1)
const pageSize = ref(15) const pageSize = ref(15)
const isOver = ref(false) const isOver = ref(false)
@ -659,7 +371,7 @@ const loading = ref(false)
// //
const fetchFinishJobs = (page) => { const fetchFinishJobs = (page) => {
loading.value = true loading.value = true
httpGet(`/api/sd/jobs?status=1&page=${page}&page_size=${pageSize.value}`).then(res => { httpGet(`/api/dall/jobs?status=1&page=${page}&page_size=${pageSize.value}`).then(res => {
if (res.data.length < pageSize.value) { if (res.data.length < pageSize.value) {
isOver.value = true isOver.value = true
} }
@ -688,29 +400,14 @@ const generate = () => {
showLoginDialog.value = true showLoginDialog.value = true
return return
} }
httpPost("/api/dall/image", params.value).then(() => {
if (params.value.seed === '') { ElMessage.success("任务执行成功!")
params.value.seed = -1 power.value -= dallPower.value
}
params.value.session_id = getSessionId()
httpPost("/api/sd/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
power.value -= sdPower.value
}).catch(e => { }).catch(e => {
ElMessage.error("任务推送失败:" + e.message) ElMessage.error("任务执行失败:" + e.message)
}) })
} }
const showTask = (row) => {
item.value = row
showTaskDialog.value = true
}
const copyParams = (row) => {
params.value = row.params
showTaskDialog.value = false
}
const removeImage = (event, item) => { const removeImage = (event, item) => {
event.stopPropagation() event.stopPropagation()
ElMessageBox.confirm( ElMessageBox.confirm(
@ -722,8 +419,9 @@ const removeImage = (event, item) => {
type: 'warning', type: 'warning',
} }
).then(() => { ).then(() => {
httpPost("/api/sd/remove", {id: item.id, img_url: item.img_url, user_id: userId.value}).then(() => { httpPost("/api/dall/remove", {id: item.id, img_url: item.img_url, user_id: userId.value}).then(() => {
ElMessage.success("任务删除成功") ElMessage.success("任务删除成功")
fetchFinishJobs(1)
}).catch(e => { }).catch(e => {
ElMessage.error("任务删除失败:" + e.message) ElMessage.error("任务删除失败:" + e.message)
}) })
@ -738,7 +436,7 @@ const publishImage = (event, item, action) => {
if (action === false) { if (action === false) {
text = "取消发布" text = "取消发布"
} }
httpPost("/api/sd/publish", {id: item.id, action: action}).then(() => { httpPost("/api/dall/publish", {id: item.id, action: action}).then(() => {
ElMessage.success(text + "成功") ElMessage.success(text + "成功")
item.publish = action item.publish = action
}).catch(e => { }).catch(e => {
@ -749,6 +447,6 @@ const publishImage = (event, item, action) => {
</script> </script>
<style lang="stylus"> <style lang="stylus">
@import "@/assets/css/image-sd.styl" @import "@/assets/css/image-dall.styl"
@import "@/assets/css/custom-scroll.styl" @import "@/assets/css/custom-scroll.styl"
</style> </style>

View File

@ -4,7 +4,7 @@
<div class="mj-box"> <div class="mj-box">
<h2>MidJourney 创作中心</h2> <h2>MidJourney 创作中心</h2>
<div class="mj-params" :style="{ height: mjBoxHeight + 'px' }"> <div class="mj-params" :style="{ height: paramBoxHeight + 'px' }">
<el-form :model="params" label-width="80px" label-position="left"> <el-form :model="params" label-width="80px" label-position="left">
<div class="param-line pt"> <div class="param-line pt">
<span>图片比例</span> <span>图片比例</span>
@ -607,12 +607,12 @@ import {copyObj, removeArrayItem} from "@/utils/libs";
import LoginDialog from "@/components/LoginDialog.vue"; import LoginDialog from "@/components/LoginDialog.vue";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const paramBoxHeight = ref(window.innerHeight - 150)
const showLoginDialog = ref(false) const showLoginDialog = ref(false)
window.onresize = () => { window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40 listBoxHeight.value = window.innerHeight - 40
mjBoxHeight.value = window.innerHeight - 150 paramBoxHeight.value = window.innerHeight - 150
} }
const rates = [ const rates = [
{css: "square", value: "1:1", text: "1:1", img: "/images/mj/rate_1_1.png"}, {css: "square", value: "1:1", text: "1:1", img: "/images/mj/rate_1_1.png"},
@ -733,9 +733,8 @@ const connect = () => {
page.value = 1 page.value = 1
fetchFinishJobs(page.value) fetchFinishJobs(page.value)
isOver.value = false isOver.value = false
} else {
fetchRunningJobs()
} }
fetchRunningJobs()
} }
} }
}); });

View File

@ -5,7 +5,7 @@
<div class="sd-box"> <div class="sd-box">
<h2>Stable Diffusion 创作中心</h2> <h2>Stable Diffusion 创作中心</h2>
<div class="sd-params" :style="{ height: mjBoxHeight + 'px' }"> <div class="sd-params" :style="{ height: paramBoxHeight + 'px' }">
<el-form :model="params" label-width="80px" label-position="left"> <el-form :model="params" label-width="80px" label-position="left">
<div class="param-line" style="padding-top: 10px"> <div class="param-line" style="padding-top: 10px">
<el-form-item label="采样方法"> <el-form-item label="采样方法">
@ -254,8 +254,14 @@
</div> </div>
<div class="text-info"> <div class="text-info">
<el-tag>每次绘图消耗{{ sdPower }}算力</el-tag> <el-row :gutter="10">
<el-tag type="success">当前可用算力{{ power }}</el-tag> <el-col :span="12">
<el-tag>单次绘图消耗{{ sdPower }}算力</el-tag>
</el-col>
<el-col :span="12">
<el-tag type="success">当前可用{{ power }}算力</el-tag>
</el-col>
</el-row>
</div> </div>
</el-form> </el-form>
@ -492,7 +498,7 @@ import {getSessionId} from "@/store/session";
import LoginDialog from "@/components/LoginDialog.vue"; import LoginDialog from "@/components/LoginDialog.vue";
const listBoxHeight = ref(window.innerHeight - 40) const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150) const paramBoxHeight = ref(window.innerHeight - 150)
const fullImgHeight = ref(window.innerHeight - 60) const fullImgHeight = ref(window.innerHeight - 60)
const showTaskDialog = ref(false) const showTaskDialog = ref(false)
const item = ref({}) const item = ref({})
@ -501,7 +507,7 @@ const isLogin = ref(false)
window.onresize = () => { window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40 listBoxHeight.value = window.innerHeight - 40
mjBoxHeight.value = window.innerHeight - 150 paramBoxHeight.value = window.innerHeight - 150
} }
const samplers = ["Euler a", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras"] const samplers = ["Euler a", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras"]
const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"] const scaleAlg = ["Latent", "ESRGAN_4x", "R-ESRGAN 4x+", "SwinIR_4x", "LDSR"]
@ -576,9 +582,8 @@ const connect = () => {
page.value = 1 page.value = 1
fetchFinishJobs(page.value) fetchFinishJobs(page.value)
isOver.value = false isOver.value = false
} else {
fetchRunningJobs()
} }
fetchRunningJobs()
} }
} }
}); });