feat: stable diffusion page is ready

This commit is contained in:
RockYang 2023-09-28 18:09:45 +08:00
parent 75c5ebbffa
commit c5776ce41f
23 changed files with 1730 additions and 779 deletions

View File

@ -157,7 +157,9 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
var tokenString string
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
tokenString = c.GetHeader(types.AdminAuthHeader)
} else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" {
} else if c.Request.URL.Path == "/api/chat/new" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/sd/client" {
tokenString = c.Query("token")
} else {
tokenString = c.GetHeader(types.UserAuthHeader)

View File

@ -101,13 +101,13 @@ type ModelAPIConfig struct {
}
type SystemConfig struct {
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
EnabledRegister bool `json:"enabled_register"`
EnabledMsgService bool `json:"enabled_msg_service"`
EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
Title string `json:"title"`
AdminTitle string `json:"admin_title"`
Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
EnabledRegister bool `json:"enabled_register"`
EnabledMsg bool `json:"enabled_msg"` // 启用短信验证码服务
EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
}

View File

@ -40,7 +40,7 @@ type MjTask struct {
}
type SdTask struct {
Id int `json:"id"`
Id int `json:"id"` // job 数据库ID
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
@ -52,18 +52,18 @@ type SdTask struct {
type SdTaskParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Steps int `json:"steps"`
Sampler string `json:"sampler"`
FaceFix bool `json:"face_fix"`
CfgScale float32 `json:"cfg_scale"`
Seed int64 `json:"seed"`
Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"`
HdRedrawRate float32 `json:"hd_redraw_rate"`
HdScale int `json:"hd_scale"`
HdScaleAlg string `json:"hd_scale_alg"`
HdSampleNum int `json:"hd_sample_num"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}

View File

@ -4,7 +4,6 @@ import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@ -17,33 +16,25 @@ import (
"gorm.io/gorm"
"net/http"
"strings"
"sync"
"time"
)
type MidJourneyHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
mjService *mj.Service
uploaderManager *oss.UploaderManager
lock sync.Mutex
clients *types.LMap[string, *types.WsClient]
redis *redis.Client
db *gorm.DB
mjService *mj.Service
}
func NewMidJourneyHandler(
app *core.AppServer,
client *redis.Client,
db *gorm.DB,
manager *oss.UploaderManager,
mjService *mj.Service) *MidJourneyHandler {
h := MidJourneyHandler{
redis: client,
db: db,
uploaderManager: manager,
lock: sync.Mutex{},
mjService: mjService,
clients: types.NewLMap[string, *types.WsClient](),
redis: client,
db: db,
mjService: mjService,
}
h.App = app
return &h
@ -59,9 +50,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.clients.Delete(sessionId)
h.clients.Put(sessionId, client)
h.mjService.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
@ -156,7 +145,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
@ -212,7 +201,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
@ -283,7 +272,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
client := h.mjService.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}

View File

@ -1,316 +1,202 @@
package handler
//
//import (
// "chatplus/core"
// "chatplus/core/types"
// "chatplus/service"
// "chatplus/service/oss"
// "chatplus/store/model"
// "chatplus/store/vo"
// "chatplus/utils"
// "chatplus/utils/resp"
// "encoding/base64"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/go-redis/redis/v8"
// "github.com/gorilla/websocket"
// "gorm.io/gorm"
// "net/http"
// "strings"
// "sync"
// "time"
//)
//
//type SdJobHandler struct {
// BaseHandler
// redis *redis.Client
// db *gorm.DB
// mjService *service.MjService
// uploaderManager *oss.UploaderManager
// lock sync.Mutex
// clients *types.LMap[string, *types.WsClient]
//}
//
//func NewSdJobHandler(
// app *core.AppServer,
// client *redis.Client,
// db *gorm.DB,
// manager *oss.UploaderManager,
// mjService *service.MjService) *MidJourneyHandler {
// h := MidJourneyHandler{
// redis: client,
// db: db,
// uploaderManager: manager,
// lock: sync.Mutex{},
// mjService: mjService,
// clients: types.NewLMap[string, *types.WsClient](),
// }
// h.App = app
// return &h
//}
//
//// Client WebSocket 客户端,用于通知任务状态变更
//func (h *SdJobHandler) Client(c *gin.Context) {
// ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
// if err != nil {
// logger.Error(err)
// return
// }
//
// sessionId := c.Query("session_id")
// client := types.NewWsClient(ws)
// // 删除旧的连接
// h.clients.Delete(sessionId)
// h.clients.Put(sessionId, client)
// logger.Infof("New websocket connected, IP: %s", c.ClientIP())
//}
//
//type sdNotifyData struct {
// TaskId string
// ImageName string
// ImageData string
// Progress int
// Seed string
// Success bool
// Message string
//}
//
//func (h *SdJobHandler) Notify(c *gin.Context) {
// token := c.GetHeader("Authorization")
// if token != h.App.Config.ExtConfig.Token {
// resp.NotAuth(c)
// return
// }
// var data sdNotifyData
// if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
// resp.ERROR(c, types.InvalidArgs)
// return
// }
// logger.Debugf("收到 MidJourney 回调请求:%+v", data)
//
// h.lock.Lock()
// defer h.lock.Unlock()
//
// err, finished := h.notifyHandler(c, data)
// if err != nil {
// resp.ERROR(c, err.Error())
// return
// }
//
// // 解除任务锁定
// if finished && (data.Progress == 100) {
// h.redis.Del(c, service.MjRunningJobKey)
// }
// resp.SUCCESS(c)
//
//}
//
//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
// taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
// if err != nil { // 过期任务,丢弃
// logger.Warn("任务已过期:", err)
// return nil, true
// }
//
// var task types.SdTask
// err = utils.JsonDecode(taskString, &task)
// if err != nil { // 非标准任务,丢弃
// logger.Warn("任务解析失败:", err)
// return nil, false
// }
//
// var job model.SdJob
// res := h.db.Where("id = ?", task.Id).First(&job)
// if res.Error != nil {
// logger.Warn("非法任务:", res.Error)
// return nil, false
// }
// job.Params = utils.JsonEncode(task.Params)
// job.ReferenceId = data.ImageData
// job.Progress = data.Progress
// job.Prompt = data.Prompt
// job.Hash = data.Image.Hash
//
// // 任务完成,将最终的图片下载下来
// if data.Progress == 100 {
// imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
// if err != nil {
// logger.Error("error with download img: ", err.Error())
// return err, false
// }
// job.ImgURL = imgURL
// } else {
// // 临时图片直接保存,访问的时候使用代理进行转发
// job.ImgURL = data.Image.URL
// }
// res = h.db.Updates(&job)
// if res.Error != nil {
// logger.Error("error with update job: ", res.Error)
// return res.Error, false
// }
//
// var jobVo vo.MidJourneyJob
// err := utils.CopyObject(job, &jobVo)
// if err == nil {
// if data.Progress < 100 {
// image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
// if err == nil {
// jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
// }
// }
//
// // 推送任务到前端
// client := h.clients.Get(task.SessionId)
// if client != nil {
// utils.ReplyChunkMessage(client, jobVo)
// }
// }
//
// // 更新用户剩余绘图次数
// if data.Progress == 100 {
// h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
// }
//
// return nil, true
//}
//
//func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// user, err := utils.GetLoginUser(c, h.db)
// if err != nil {
// resp.NotAuth(c)
// return false
// }
//
// if user.ImgCalls <= 0 {
// resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
// return false
// }
//
// return true
//
//}
//
//// Image 创建一个绘画任务
//func (h *SdJobHandler) Image(c *gin.Context) {
// var data struct {
// SessionId string `json:"session_id"`
// Prompt string `json:"prompt"`
// Rate string `json:"rate"`
// Model string `json:"model"`
// Chaos int `json:"chaos"`
// Raw bool `json:"raw"`
// Seed int64 `json:"seed"`
// Stylize int `json:"stylize"`
// Img string `json:"img"`
// Weight float32 `json:"weight"`
// }
// if err := c.ShouldBindJSON(&data); err != nil {
// resp.ERROR(c, types.InvalidArgs)
// return
// }
// if !h.checkLimits(c) {
// return
// }
//
// var prompt = data.Prompt
// if data.Rate != "" && !strings.Contains(prompt, "--ar") {
// prompt += " --ar " + data.Rate
// }
// if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
// prompt += fmt.Sprintf(" --seed %d", data.Seed)
// }
// if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
// prompt += fmt.Sprintf(" --s %d", data.Stylize)
// }
// if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
// prompt += fmt.Sprintf(" --c %d", data.Chaos)
// }
// if data.Img != "" {
// prompt = fmt.Sprintf("%s %s", data.Img, prompt)
// if data.Weight > 0 {
// prompt += fmt.Sprintf(" --iw %f", data.Weight)
// }
// }
// if data.Raw {
// prompt += " --style raw"
// }
// if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
// prompt += data.Model
// }
//
// idValue, _ := c.Get(types.LoginUserID)
// userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
// job := model.MidJourneyJob{
// Type: service.Image.String(),
// UserId: userId,
// Progress: 0,
// Prompt: prompt,
// CreatedAt: time.Now(),
// }
// if res := h.db.Create(&job); res.Error != nil {
// resp.ERROR(c, "添加任务失败:"+res.Error.Error())
// return
// }
//
// h.mjService.PushTask(service.MjTask{
// Id: int(job.Id),
// SessionId: data.SessionId,
// Src: service.TaskSrcImg,
// Type: service.Image,
// Prompt: prompt,
// UserId: userId,
// })
//
// var jobVo vo.MidJourneyJob
// err := utils.CopyObject(job, &jobVo)
// if err == nil {
// // 推送任务到前端
// client := h.clients.Get(data.SessionId)
// if client != nil {
// utils.ReplyChunkMessage(client, jobVo)
// }
// }
// resp.SUCCESS(c)
//}
//
//// JobList 获取 MJ 任务列表
//func (h *SdJobHandler) JobList(c *gin.Context) {
// status := h.GetInt(c, "status", 0)
// var items []model.MidJourneyJob
// var res *gorm.DB
// userId, _ := c.Get(types.LoginUserID)
// if status == 1 {
// res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
// } else {
// res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
// }
// if res.Error != nil {
// resp.ERROR(c, types.NoData)
// return
// }
//
// var jobs = make([]vo.MidJourneyJob, 0)
// for _, item := range items {
// var job vo.MidJourneyJob
// err := utils.CopyObject(item, &job)
// if err != nil {
// continue
// }
// if item.Progress < 100 {
// // 30 分钟还没完成的任务直接删除
// if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
// h.db.Delete(&item)
// continue
// }
// if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
// image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
// if err == nil {
// job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
// }
// }
// }
// jobs = append(jobs, job)
// }
// resp.SUCCESS(c, jobs)
//}
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/sd"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
service *sd.Service
}
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler {
h := SdJobHandler{
redis: redisCli,
db: db,
service: service,
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.service.Clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.App.Config.SdConfig.Enabled {
resp.ERROR(c, "Stable Diffusion service is disabled")
return
}
if !h.checkLimits(c) {
return
}
var data struct {
SessionId string `json:"session_id"`
types.SdTaskParams
}
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Width <= 0 {
data.Width = 512
}
if data.Height <= 0 {
data.Height = 512
}
if data.CfgScale <= 0 {
data.CfgScale = 7
}
if data.Seed == 0 {
data.Seed = -1
}
if data.Steps <= 0 {
data.Steps = 20
}
if data.Sampler == "" {
data.Sampler = "Euler a"
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
params := types.SdTaskParams{
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
Prompt: data.Prompt,
NegativePrompt: data.NegativePrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
}
job := model.SdJob{
UserId: userId,
Type: types.TaskImage.String(),
TaskId: params.TaskId,
Params: utils.JsonEncode(params),
Prompt: data.Prompt,
Progress: 0,
Started: false,
CreatedAt: time.Now(),
}
res := h.db.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.SdTask{
Id: int(job.Id),
SessionId: data.SessionId,
Src: types.TaskSrcImg,
Type: types.TaskImage,
Prompt: data.Prompt,
Params: params,
UserId: userId,
})
var jobVo vo.SdJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.service.Clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
var items []model.SdJob
var res *gorm.DB
userId, _ := c.Get(types.LoginUserID)
if status == 1 {
res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
} else {
res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
}
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
}
var jobs = make([]vo.SdJob, 0)
for _, item := range items {
var job vo.SdJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
h.db.Delete(&item)
continue
}
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
}

View File

@ -66,5 +66,5 @@ type statusVo struct {
// Status check if the message service is enabled
func (h *SmsHandler) Status(c *gin.Context) {
resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsgService, EnabledRegister: h.App.SysConfig.EnabledRegister})
resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister})
}

View File

@ -63,7 +63,7 @@ func (h *UserHandler) Register(c *gin.Context) {
// 检查验证码
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsgService {
if h.App.SysConfig.EnabledMsg {
var code int
err := h.leveldb.Get(key, &code)
if err != nil || code != data.Code {
@ -113,7 +113,7 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
if h.App.SysConfig.EnabledMsgService {
if h.App.SysConfig.EnabledMsg {
_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
}
resp.SUCCESS(c, user)

View File

@ -10,6 +10,7 @@ import (
"chatplus/service/fun"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/service/wx"
"chatplus/store"
"context"
@ -121,6 +122,7 @@ func main() {
fx.Provide(handler.NewCaptchaHandler),
fx.Provide(handler.NewMidJourneyHandler),
fx.Provide(handler.NewChatModelHandler),
fx.Provide(handler.NewSdJobHandler),
fx.Provide(admin.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
@ -167,6 +169,13 @@ func main() {
}
}),
// Stable Diffusion 机器人
fx.Provide(sd.NewService),
fx.Invoke(func(service *sd.Service) {
go func() {
service.Run()
}()
}),
// 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/role/")
@ -220,6 +229,12 @@ func main() {
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.Any("client", h.Client)
}),
// 管理后台控制器
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {

View File

@ -20,7 +20,7 @@ import (
const RunningJobKey = "MidJourney_Running_Job"
type Service struct {
client *Client
client *Client // MJ 客户端
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
@ -128,7 +128,7 @@ func (s *Service) Notify(data CBReq) {
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
@ -169,7 +169,7 @@ func (s *Service) Notify(data CBReq) {
utils.ReplyMessage(wsClient, content)
}
// download image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
if err != nil {
logger.Error("error with download image: ", err)
if wsClient != nil && data.ReferenceId != "" {

View File

@ -63,8 +63,14 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
}
func (s AliYunOss) PutImg(imageURL string) (string, error) {
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}

View File

@ -41,14 +41,18 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
}
func (s LocalStorage) PutImg(imageURL string) (string, error) {
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
filename := filepath.Base(imageURL)
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
if err != nil {
return "", fmt.Errorf("error with generate image dir: %v", err)
}
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
if useProxy {
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
} else {
err = utils.DownloadFile(imageURL, filePath, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}

View File

@ -31,8 +31,14 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
}
func (s MiniOss) PutImg(imageURL string) (string, error) {
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}

View File

@ -72,8 +72,14 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutImg(imageURL string) (string, error) {
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
var imageData []byte
var err error
if useProxy {
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
} else {
imageData, err = utils.DownloadImage(imageURL, "")
}
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}

View File

@ -4,6 +4,6 @@ import "github.com/gin-gonic/gin"
type Uploader interface {
PutFile(ctx *gin.Context, name string) (string, error)
PutImg(imageURL string) (string, error)
PutImg(imageURL string, useProxy bool) (string, error)
Delete(fileURL string) error
}

View File

@ -1,169 +0,0 @@
package sd
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/imroc/req/v3"
"io"
"time"
)
type Client struct {
httpClient *req.Client
config *types.StableDiffusionConfig
}
func NewSdClient(config *types.AppConfig) *Client {
return &Client{
config: &config.SdConfig,
httpClient: req.C(),
}
}
func (c *Client) Txt2Img(params types.SdTaskParams) error {
var data []interface{}
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
if err != nil {
return err
}
data[ParamKeys["task_id"]] = params.TaskId
data[ParamKeys["prompt"]] = params.Prompt
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
data[ParamKeys["steps"]] = params.Steps
data[ParamKeys["sampler"]] = params.Sampler
data[ParamKeys["face_fix"]] = params.FaceFix
data[ParamKeys["cfg_scale"]] = params.CfgScale
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSampleNum
task := TaskInfo{
TaskId: params.TaskId,
Data: data,
EventData: nil,
FnIndex: 494,
SessionHash: "ycaxgzm9ah",
}
go func() {
c.runTask(task, c.httpClient)
}()
return nil
}
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{TaskId: taskInfo.TaskId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
if err != nil {
cbReq.Message = "error with send request: " + err.Error()
cbReq.Success = false
result <- cbReq
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
return
}
var images []struct {
Name string `json:"name"`
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
if err != nil {
cbReq.Message = "error with decode image:" + err.Error()
cbReq.Success = false
result <- cbReq
return
}
var info map[string]any
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
if err != nil {
cbReq.Message = err.Error()
cbReq.Success = false
result <- cbReq
return
}
//for k, v := range info {
// fmt.Println(k, " => ", v)
//}
cbReq.ImageName = images[0].Name
cbReq.Seed = utils.InterfaceToString(info["seed"])
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
}()
for {
select {
case value := <-result:
if value.Success {
logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
}
return
default:
var progressReq = map[string]any{
"id_task": taskInfo.TaskId,
"id_live_preview": 1,
}
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
fmt.Println("Progress: ", progressRes.Progress)
fmt.Println("Image: ", progressRes.LivePreview)
time.Sleep(time.Second)
}
}
}

View File

@ -1,72 +0,0 @@
package sd
import (
"chatplus/core/types"
"chatplus/service/mj"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"context"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
"time"
)
// SD 绘画服务
const RunningJobKey = "StableDiffusion_Running_Job"
type Service struct {
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
Client *Client
}
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
return &Service{
redis: redisCli,
db: db,
Client: client,
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
}
}
func (s *Service) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
err = s.Client.Txt2Img(task.Params)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
continue
}
// 更新任务的执行状态
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
}
}
func (s *Service) PushTask(task types.SdTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}

300
api/service/sd/service.go Normal file
View File

@ -0,0 +1,300 @@
package sd
import (
"chatplus/core/types"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"strconv"
"time"
)
// SD 绘画服务
const RunningJobKey = "StableDiffusion_Running_Job"
type Service struct {
httpClient *req.Client
config *types.StableDiffusionConfig
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
uploadManager *oss.UploaderManager
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池
}
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service {
return &Service{
config: &config.SdConfig,
httpClient: req.C(),
redis: redisCli,
db: db,
uploadManager: manager,
Clients: types.NewLMap[string, *types.WsClient](),
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
}
}
func (s *Service) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, RunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
err = s.Txt2Img(task)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
continue
}
// 更新任务的执行状态
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
}
}
// PushTask 推送任务到队列
func (s *Service) PushTask(task types.SdTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}
// Txt2Img 文生图 API
func (s *Service) Txt2Img(task types.SdTask) error {
var data []interface{}
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
if err != nil {
return err
}
params := task.Params
data[ParamKeys["task_id"]] = params.TaskId
data[ParamKeys["prompt"]] = params.Prompt
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
data[ParamKeys["steps"]] = params.Steps
data[ParamKeys["sampler"]] = params.Sampler
data[ParamKeys["face_fix"]] = params.FaceFix
data[ParamKeys["cfg_scale"]] = params.CfgScale
data[ParamKeys["seed"]] = params.Seed
data[ParamKeys["height"]] = params.Height
data[ParamKeys["width"]] = params.Width
data[ParamKeys["hd_fix"]] = params.HdFix
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
data[ParamKeys["hd_scale"]] = params.HdScale
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
data[ParamKeys["hd_sample_num"]] = params.HdSteps
go func() {
s.runTask(TaskInfo{
SessionId: task.SessionId,
JobId: task.Id,
TaskId: params.TaskId,
Data: data,
EventData: nil,
FnIndex: 405,
SessionHash: "ycaxgzm9ah",
}, s.httpClient)
}()
return nil
}
// 执行任务
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
body := map[string]any{
"data": taskInfo.Data,
"event_data": taskInfo.EventData,
"fn_index": taskInfo.FnIndex,
"session_hash": taskInfo.SessionHash,
}
logger.Debug(utils.JsonEncode(body))
var result = make(chan CBReq)
go func() {
var res struct {
Data []interface{} `json:"data"`
IsGenerating bool `json:"is_generating"`
Duration float64 `json:"duration"`
AverageDuration float64 `json:"average_duration"`
}
var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
if err != nil {
cbReq.Message = "error with send request: " + err.Error()
cbReq.Success = false
result <- cbReq
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
cbReq.Message = "error http status code: " + string(bytes)
cbReq.Success = false
result <- cbReq
return
}
var images []struct {
Name string `json:"name"`
Data interface{} `json:"data"`
IsFile bool `json:"is_file"`
}
err = utils.ForceCovert(res.Data[0], &images)
if err != nil {
cbReq.Message = "error with decode image:" + err.Error()
cbReq.Success = false
result <- cbReq
return
}
var info map[string]any
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
if err != nil {
cbReq.Message = err.Error()
cbReq.Success = false
result <- cbReq
return
}
//for k, v := range info {
// fmt.Println(k, " => ", v)
//}
cbReq.ImageName = images[0].Name
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
cbReq.Seed = seed
cbReq.Success = true
cbReq.Progress = 100
result <- cbReq
close(result)
}()
for {
select {
case value := <-result:
s.callback(value)
return
default:
var progressReq = map[string]any{
"id_task": taskInfo.TaskId,
"id_live_preview": 1,
}
var progressRes struct {
Active bool `json:"active"`
Queued bool `json:"queued"`
Completed bool `json:"completed"`
Progress float64 `json:"progress"`
Eta float64 `json:"eta"`
LivePreview string `json:"live_preview"`
IDLivePreview int `json:"id_live_preview"`
TextInfo interface{} `json:"textinfo"`
}
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
if err != nil { // TODO: 这里可以考虑设置失败重试次数
logger.Error(err)
return
}
if response.IsErrorState() {
bytes, _ := io.ReadAll(response.Body)
logger.Error(string(bytes))
return
}
cbReq.ImageData = progressRes.LivePreview
cbReq.Progress = int(progressRes.Progress * 100)
s.callback(cbReq)
time.Sleep(time.Second)
}
}
}
func (s *Service) callback(data CBReq) {
client := s.Clients.Get(data.SessionId)
if data.Success { // 任务成功
var job model.SdJob
res := s.db.Where("id = ?", data.JobId).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return
}
// 更新任务进度
job.Progress = data.Progress
// 更新任务 seed
var params types.SdTaskParams
err := utils.JsonDecode(job.Params, &params)
if err != nil {
logger.Error("任务解析失败:", err)
return
}
params.Seed = data.Seed
if data.ImageName != "" { // 下载图片
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false)
if err != nil {
logger.Error("error with download img: ", err.Error())
return
}
job.ImgURL = imageURL
}
res = s.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return
}
var jobVo vo.SdJob
err = utils.CopyObject(job, &jobVo)
if err != nil {
logger.Error("error with copy object: ", err)
return
}
if data.Progress < 100 {
logger.Infof(data.ImageData)
jobVo.ImgURL = data.ImageData
}
// 推送任务到前端
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
} else { // 任务失败
logger.Error("任务执行失败:", data.Message)
// 删除任务
s.db.Delete(&model.SdJob{Id: uint(data.JobId)})
// 推送消息到前端
if client != nil {
utils.ReplyChunkMessage(client, vo.SdJob{
Id: uint(data.JobId),
Progress: -1,
Prompt: fmt.Sprintf("任务[%s]执行失败,已删除!", data.TaskId),
})
}
}
}

View File

@ -5,19 +5,23 @@ import logger2 "chatplus/logger"
var logger = logger2.GetLogger()
type TaskInfo struct {
TaskId string `json:"task_id"`
Data interface{} `json:"data"`
EventData interface{} `json:"event_data"`
FnIndex int `json:"fn_index"`
SessionHash string `json:"session_hash"`
SessionId string
JobId int
TaskId string
Data []interface{}
EventData interface{}
FnIndex int
SessionHash string
}
type CBReq struct {
SessionId string
JobId int
TaskId string
ImageName string
ImageData string
Progress int
Seed string
Seed int64
Success bool
Message string
}
@ -41,164 +45,170 @@ var ParamKeys = map[string]int{
}
const Text2ImgParamTemplate = `[
"",
"",
"task(p1lk3n41saygmr8)",
"a tiger sit on the window",
"",
[],
30,
"DPM++ SDE Karras",
20,
"Euler a",
false,
false,
1,
1,
7.5,
7,
-1,
-1,
0,
0,
0,
false,
512,
512,
true,
128,
128,
false,
0.7,
2,
"Latent",
10,
0,
0,
"Use same sampler",
"",
"",
0,
[],
"None",
false,
"MultiDiffusion",
false,
10,
1,
1,
64,
false,
true,
1024,
1024,
96,
96,
48,
4,
1,
"None",
2,
false,
10,
false,
false,
false,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
false,
true,
true,
false,
1536,
96,
false,
false,
"LoRA",
"None",
1,
1,
64,
false,
false,
false,
false,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
0.4,
0.4,
0.2,
0.2,
"",
"",
"Background",
0.2,
-1,
false,
3072,
192,
true,
true,
true,
false,
"LoRA",
"None",
1,
1,
"LoRA",
"None",
1,
1,
"LoRA",
"None",
1,
1,
"LoRA",
"None",
1,
1,
null,
"Refresh models",
null,
null,
null,
null,
false,
"",
0.5,
true,
false,
"",
"Lerp",
false,
"🔄",
false,
false,
false,
false,
false,
false,
false,
false,
false,
"positive",
@ -209,26 +219,26 @@ false,
"",
"Seed",
"",
[],
"Nothing",
"",
[],
"Nothing",
"",
[],
true,
false,
false,
false,
0,
null,
false,
null,
false,
null,
null,
false,
null,
null,
false,
50
50,
[],
"",
"",
""
]`

View File

@ -0,0 +1,187 @@
.page-sd {
background-color: #282c34;
}
.page-sd .inner {
display: flex;
/* 修改滚动条的颜色 */
/* 修改滚动条轨道的背景颜色 */
/* 修改滚动条的滑块颜色 */
/* 修改滚动条的滑块的悬停颜色 */
}
.page-sd .inner .sd-box {
margin: 10px;
background-color: #262626;
border: 1px solid #454545;
min-width: 300px;
max-width: 300px;
padding: 10px;
border-radius: 10px;
color: #fff;
font-size: 14px;
}
.page-sd .inner .sd-box h2 {
font-weight: bold;
font-size: 20px;
text-align: center;
color: #47fff1;
}
.page-sd .inner .sd-box ::-webkit-scrollbar {
width: 0;
height: 0;
background-color: transparent;
}
.page-sd .inner .sd-box .sd-params {
margin-top: 10px;
overflow: auto;
}
.page-sd .inner .sd-box .sd-params .param-line {
padding: 0 10px;
}
.page-sd .inner .sd-box .sd-params .param-line .el-icon {
position: relative;
top: 3px;
}
.page-sd .inner .sd-box .sd-params .param-line .el-input__suffix-inner .el-icon {
top: 0;
}
.page-sd .inner .sd-box .sd-params .param-line .grid-content,
.page-sd .inner .sd-box .sd-params .param-line .form-item-inner {
display: flex;
}
.page-sd .inner .sd-box .sd-params .param-line .grid-content .el-icon,
.page-sd .inner .sd-box .sd-params .param-line .form-item-inner .el-icon {
margin-left: 10px;
margin-top: 2px;
}
.page-sd .inner .sd-box .sd-params .param-line.pt {
padding-top: 5px;
padding-bottom: 5px;
}
.page-sd .inner .sd-box .submit-btn {
padding: 10px 15px 0 15px;
text-align: center;
}
.page-sd .inner .sd-box .submit-btn .el-button {
width: 100%;
}
.page-sd .inner .sd-box .submit-btn .el-button span {
color: #2d3a4b;
}
.page-sd .inner .el-form .el-form-item__label {
color: #fff;
}
.page-sd .inner ::-webkit-scrollbar {
width: 10px; /* 滚动条宽度 */
}
.page-sd .inner ::-webkit-scrollbar-track {
background-color: #282c34;
}
.page-sd .inner ::-webkit-scrollbar-thumb {
background-color: #444;
border-radius: 10px;
}
.page-sd .inner ::-webkit-scrollbar-thumb:hover {
background-color: #666;
}
.page-sd .inner .task-list-box {
width: 100%;
padding: 10px;
color: #fff;
overflow-x: hidden;
}
.page-sd .inner .task-list-box .running-job-list .job-item {
width: 100%;
padding: 2px;
background-color: #555;
}
.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner {
position: relative;
height: 100%;
overflow: hidden;
}
.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner .progress {
position: absolute;
width: 100%;
height: 100%;
top: 0;
left: 0;
display: flex;
justify-content: center;
align-items: center;
}
.page-sd .inner .task-list-box .running-job-list .job-item .job-item-inner .progress span {
font-size: 20px;
color: #fff;
}
.page-sd .inner .task-list-box .finish-job-list .job-item {
width: 100%;
height: 100%;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line {
margin: 6px 0;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul {
display: flex;
flex-flow: row;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li {
margin-right: 10px;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a {
padding: 3px 0;
width: 44px;
text-align: center;
border-radius: 5px;
display: block;
cursor: pointer;
background-color: #4e5058;
color: #fff;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul li a:hover {
background-color: #6d6f78;
}
.page-sd .inner .task-list-box .finish-job-list .job-item .opt .opt-line ul .show-prompt {
font-size: 20px;
cursor: pointer;
}
.page-sd .inner .task-list-box .el-image {
width: 100%;
height: 100%;
max-height: 240px;
}
.page-sd .inner .task-list-box .el-image img {
height: 240px;
}
.page-sd .inner .task-list-box .el-image .el-image-viewer__wrapper img {
width: auto;
height: auto;
}
.page-sd .inner .task-list-box .el-image .image-slot {
display: flex;
flex-flow: column;
justify-content: center;
align-items: center;
height: 100%;
min-height: 200px;
color: #fff;
}
.page-sd .inner .task-list-box .el-image .image-slot .iconfont {
font-size: 50px;
margin-bottom: 10px;
}
.page-sd .inner .task-list-box .el-image.upscale {
max-height: 304px;
}
.page-sd .inner .task-list-box .el-image.upscale img {
height: 304px;
}
.page-sd .inner .task-list-box .el-image.upscale .el-image-viewer__wrapper img {
width: auto;
height: auto;
}
.mj-list-item-prompt .el-icon {
margin-left: 10px;
cursor: pointer;
position: relative;
top: 2px;
}

View File

@ -0,0 +1,255 @@
.page-sd {
background-color: #282c34;
.inner {
display: flex;
.sd-box {
margin 10px
background-color #262626
border 1px solid #454545
min-width 300px
max-width 300px
padding 10px
border-radius 10px
color #ffffff;
font-size 14px
h2 {
font-weight: bold;
font-size 20px
text-align center
color #47fff1
}
//
::-webkit-scrollbar {
width: 0;
height: 0;
background-color: transparent;
}
.sd-params {
margin-top 10px
overflow auto
.param-line {
padding 0 10px
.el-icon {
position relative
top 3px
}
.el-input__suffix-inner {
.el-icon {
top 0
}
}
.grid-content
.form-item-inner {
display flex
.el-icon {
margin-left 10px
margin-top 2px
}
}
}
.param-line.pt {
padding-top 5px
padding-bottom 5px
}
}
.submit-btn {
padding 10px 15px 0 15px
text-align center
.el-button {
width 100%
span {
color #2D3A4B
}
}
}
}
.el-form {
.el-form-item__label {
color #ffffff
}
}
/* */
::-webkit-scrollbar {
width: 10px; /* */
}
/* */
::-webkit-scrollbar-track {
background-color: #282C34;
}
/* */
::-webkit-scrollbar-thumb {
background-color: #444444;
border-radius 10px
}
/* */
::-webkit-scrollbar-thumb:hover {
background-color: #666666;
}
.task-list-box {
width 100%
padding 10px
color #ffffff
overflow-x hidden
.running-job-list {
.job-item {
//border: 1px solid #454545;
width: 100%;
padding 2px
background-color #555555
.job-item-inner {
position relative
height 100%
overflow hidden
.progress {
position absolute
width 100%
height 100%
top 0
left 0
display flex
justify-content center
align-items center
span {
font-size 20px
color #ffffff
}
}
}
}
}
.finish-job-list {
.job-item {
width 100%
height 100%
.opt {
.opt-line {
margin 6px 0
ul {
display flex
flex-flow row
li {
margin-right 10px
a {
padding 3px 0
width 44px
text-align center
border-radius 5px
display block
cursor pointer
background-color #4E5058
color #ffffff
&:hover {
background-color #6D6F78
}
}
}
.show-prompt {
font-size 20px
cursor pointer
}
}
}
}
}
}
.el-image {
width 100%
height 100%
max-height 240px
img {
height 240px
}
.el-image-viewer__wrapper {
img {
width auto
height auto
}
}
.image-slot {
display flex
flex-flow column
justify-content center
align-items center
height 100%
min-height 200px
color #ffffff
.iconfont {
font-size 50px
margin-bottom 10px
}
}
}
.el-image.upscale {
max-height 304px
img {
height 304px
}
.el-image-viewer__wrapper {
img {
width auto
height auto
}
}
}
}
}
}
.mj-list-item-prompt {
.el-icon {
margin-left 10px
cursor pointer
position relative
top 2px
}
}

View File

@ -230,13 +230,13 @@
placement="top-start"
:title="getTaskType(scope.item.type)"
:width="240"
trigger="click"
trigger="hover"
>
<template #reference>
<div v-if="scope.item.progress > 0" class="job-item-inner">
<el-image :src="scope.item.img_url"
<el-image :src="scope.item['img_url']"
:zoom-rate="1.2"
:preview-src-list="[scope.item.img_url]"
:preview-src-list="[scope.item['img_url']]"
fit="cover"
:initial-index="0" loading="lazy">
<template #placeholder>
@ -289,7 +289,7 @@
<template #default="scope">
<div class="job-item">
<el-image
:src="scope.item.type === 'upscale'?scope.item.img_url+'?imageView2/1/w/240/h/300/q/75':scope.item.img_url+'?imageView2/1/w/240/h/240/q/75'"
:src="scope.item.type === 'upscale'?scope.item['img_url']+'?imageView2/1/w/240/h/300/q/75':scope.item['img_url']+'?imageView2/1/w/240/h/240/q/75'"
:class="scope.item.type === 'upscale'?'upscale':''"
:zoom-rate="1.2"
:preview-src-list="previewImgList"
@ -359,7 +359,6 @@
</div> <!-- end finish job list-->
</div>
<el-backtop :right="100" :bottom="100"/>
</div><!-- end task list box -->
</div>

View File

@ -1,41 +1,568 @@
<template>
<div class="page-sd" :style="{ height: winHeight + 'px' }">
<div class="page-sd">
<div class="inner">
<h1>Stable Diffusion 绘画中心</h1>
<h2>页面正在紧锣密鼓开发中敬请期待</h2>
<div class="sd-box">
<h2>MidJourney 创作中心</h2>
<div class="sd-params" :style="{ height: mjBoxHeight + 'px' }">
<el-form :model="params" label-width="80px" label-position="left">
<div class="param-line pt">
<span>图片比例</span>
<el-tooltip
effect="light"
content="生成图片的尺寸比例"
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
<div class="param-line pt">
<el-row :gutter="10">
<el-col :span="8" v-for="item in rates" :key="item.value">
<div :class="item.value === params.rate?'grid-content active':'grid-content'"
@click="changeRate(item)">
<div :class="'shape '+item.css"></div>
<div class="text">{{ item.text }}</div>
</div>
</el-col>
</el-row>
</div>
<div class="param-line" style="padding-top: 10px">
<el-form-item label="采样方法">
<template #default>
<div class="form-item-inner">
<el-select v-model="params.sampler" size="small">
<el-option v-for="item in samplers" :label="item" :value="item"/>
</el-select>
<el-tooltip
effect="light"
content="出图效果比较好的一般是 Euler 和 DPM 系列算法"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="图片尺寸">
<template #default>
<div class="form-item-inner">
<el-row :gutter="20">
<el-col :span="12">
<el-input v-model.number="params.width" size="small" placeholder="图片宽度"/>
</el-col>
<el-col :span="12">
<el-input v-model.number="params.height" size="small" placeholder="图片高度"/>
</el-col>
</el-row>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="迭代步数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.steps" size="small"/>
<el-tooltip
effect="light"
content="值越大则代表细节越多,同时也意味着出图速度越慢"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="引导系数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.cfg_scale" size="small"/>
<el-tooltip
effect="light"
content="提示词引导系数,图像在多大程度上服从提示词<br/> 较低值会产生更有创意的结果"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="引导系数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.seed" size="small"/>
<el-tooltip
effect="light"
content="随机数种子,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="面部修复">
<template #default>
<div class="form-item-inner">
<el-switch v-model="params.face_fix" style="--el-switch-on-color: #47fff1;"/>
<el-tooltip
effect="light"
content="仅对绘制人物图像有效果。"
raw-content
placement="right"
>
<el-icon style="margin-top: 6px">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="高清修复">
<template #default>
<div class="form-item-inner">
<el-switch v-model="params.hd_fix" style="--el-switch-on-color: #47fff1;"/>
<el-tooltip
effect="light"
content="先以较小的分辨率生成图像,接着方法图像<br />然后在不更改构图的情况下再修改细节"
raw-content
placement="right"
>
<el-icon style="margin-top: 6px">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="重绘幅度">
<template #default>
<div class="form-item-inner">
<el-slider v-model.number="params.hd_redraw_rate" :max="1" :step="0.1"
style="width: 180px;--el-slider-main-bg-color:#47fff1"/>
<el-tooltip
effect="light"
content="决定算法对图像内容的影响程度<br />较大的值将得到越有创意的图像"
raw-content
placement="right"
>
<el-icon style="margin-top: 6px">
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="放大算法">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.hd_scale_alg" size="small"/>
<el-tooltip
effect="light"
content="随机数种子,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="放大倍数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.hd_scale" size="small"/>
<el-tooltip
effect="light"
content="随机数种子,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-form-item label="迭代步数">
<template #default>
<div class="form-item-inner">
<el-input v-model.number="params.hd_scale" size="small"/>
<el-tooltip
effect="light"
content="放大迭代步数,相同的种子会得到相同的结果<br/> 设置为 -1 则每次随机生成种子"
raw-content
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
</template>
</el-form-item>
</div>
<div class="param-line">
<el-input
v-model="params.prompt"
:autosize="{ minRows: 4, maxRows: 6 }"
type="textarea"
ref="promptRef"
placeholder="正向提示词例如A chinese girl walking in the middle of a cobblestone street"
/>
</div>
<div class="param-line pt">
<span>图片比例</span>
<el-tooltip
effect="light"
content="不希望出现的元素,下面给了默认的起手式"
placement="right"
>
<el-icon>
<InfoFilled/>
</el-icon>
</el-tooltip>
</div>
<div class="param-line">
<el-input
v-model="params.negative_prompt"
:autosize="{ minRows: 4, maxRows: 6 }"
type="textarea"
placeholder="反向提示词"
/>
</div>
<div class="param-line pt">
<el-form-item label="剩余次数">
<template #default>
<el-tag type="info">{{ imgCalls }}</el-tag>
</template>
</el-form-item>
</div>
</el-form>
</div>
<div class="submit-btn">
<el-button color="#47fff1" :dark="false" round @click="generate">立即生成</el-button>
</div>
</div>
<div class="task-list-box">
<div class="task-list-inner" :style="{ height: listBoxHeight + 'px' }">
<h2>任务列表</h2>
<div class="running-job-list">
<ItemList :items="runningJobs" v-if="runningJobs.length > 0">
<template #default="scope">
<div class="job-item">
<el-popover
placement="top-start"
title="绘画提示词"
:width="240"
trigger="hover"
>
<template #reference>
<div v-if="scope.item.progress > 0" class="job-item-inner">
<el-image :src="scope.item['img_url']"
fit="cover"
loading="lazy">
<template #placeholder>
<div class="image-slot">
正在加载图片
</div>
</template>
<template #error>
<div class="image-slot">
<el-icon v-if="scope.item['img_url'] !== ''">
<Picture/>
</el-icon>
</div>
</template>
</el-image>
<div class="progress">
<el-progress type="circle" :percentage="scope.item.progress" :width="100" color="#47fff1"/>
</div>
</div>
<el-image fit="cover" v-else>
<template #error>
<div class="image-slot">
<i class="iconfont icon-quick-start"></i>
<span>任务正在排队中</span>
</div>
</template>
</el-image>
</template>
<template #default>
<div class="mj-list-item-prompt">
<span>{{ scope.item.prompt }}</span>
<el-icon class="copy-prompt" :data-clipboard-text="scope.item.prompt">
<DocumentCopy/>
</el-icon>
</div>
</template>
</el-popover>
</div>
</template>
</ItemList>
<el-empty :image-size="100" v-else/>
</div>
<h2>创作记录</h2>
<div class="finish-job-list">
<ItemList :items="finishedJobs" v-if="finishedJobs.length > 0">
<template #default="scope">
<div class="job-item">
<el-image
:src="scope.item['img_url']+'?imageView2/1/w/240/h/240/q/75'"
fit="cover"
loading="lazy">
<template #placeholder>
<div class="image-slot">
正在加载图片
</div>
</template>
<template #error>
<div class="image-slot">
<el-icon>
<Picture/>
</el-icon>
</div>
</template>
</el-image>
</div>
</template>
</ItemList>
</div> <!-- end finish job list-->
</div>
</div><!-- end task list box -->
</div>
</div>
</template>
<script setup>
import {ref} from "vue"
import {onMounted, ref} from "vue"
import {ChromeFilled, DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue";
import Compressor from "compressorjs";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage} from "element-plus";
import ItemList from "@/components/ItemList.vue";
import Clipboard from "clipboard";
import {checkSession} from "@/action/session";
import {useRouter} from "vue-router";
import {getSessionId, getUserToken} from "@/store/session";
const winHeight = ref(window.innerHeight)
</script>
const listBoxHeight = ref(window.innerHeight - 40)
const mjBoxHeight = ref(window.innerHeight - 150)
<style lang="stylus" scoped>
.page-sd {
display: flex;
justify-content: center;
align-items center
background-color: #282c34;
window.onresize = () => {
listBoxHeight.value = window.innerHeight - 40
mjBoxHeight.value = window.innerHeight - 150
}
const rates = [
{css: "horizontal", value: "768x512", text: "横图"},
{css: "square", value: "512x512", text: "方图"},
{css: "vertical", value: "512x768", text: "竖图"},
]
const samplers = ["Euler a", "Euler", "DPM2 a Karras", "DPM++ 2S a Karras", "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM2", "DPM2 a", "DPM++ 2S a", "DPM++ 2M", "DPM++ SDE", "DPM fast", "DPM adaptive",
"LMS Karras", "DPM2 Karras", "DDIM", "PLMS", "UniPC", "LMS", "Heun",]
const params = ref({
rate: rates[1].value,
width: 256,
height: 256,
sampler: samplers[0],
seed: -1,
steps: 20,
cfg_scale: 7,
face_fix: false,
hd_fix: false,
hd_redraw_rate: 0.3,
hd_scale: 2,
hd_scale_alg: "ESRGAN_4x",
hd_steps: 10,
prompt: "a tiger sit on the window",
negative_prompt: "nsfw, paintings, cartoon, anime, sketches, low quality,easynegative,ng_deepnegative _v1 75t,(worst quality:2),(low quality:2),(normalquality:2),lowres,bad anatomy,bad hands,normal quality,((monochrome)),((grayscale)),((watermark))",
})
.inner {
text-align center
const runningJobs = ref([])
const finishedJobs = ref([])
const previewImgList = ref([])
const router = useRouter()
h1 {
color: #202020;
font-size: 80px;
font-weight: bold;
letter-spacing: 0.1em;
text-shadow: -1px -1px 1px #111111, 2px 2px 1px #363636;
}
const socket = ref(null)
const imgCalls = ref(0)
h2 {
color #ffffff;
font-weight: bold;
const connect = () => {
let host = process.env.VUE_APP_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host;
} else {
host = 'ws://' + location.host;
}
}
const _socket = new WebSocket(host + `/api/sd/client?session_id=${getSessionId()}&token=${getUserToken()}`);
_socket.addEventListener('open', () => {
socket.value = _socket;
});
_socket.addEventListener('message', event => {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
const data = JSON.parse(String(reader.result));
let append = true
if (data.progress === 100) { //
for (let i = 0; i < finishedJobs.value.length; i++) {
if (finishedJobs.value[i].id === data.id) {
append = false
break
}
}
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
runningJobs.value.splice(i, 1)
break
}
}
if (append) {
finishedJobs.value.unshift(data)
}
previewImgList.value.unshift(data["img_url"])
} else { //
for (let i = 0; i < runningJobs.value.length; i++) {
if (runningJobs.value[i].id === data.id) {
append = false
runningJobs.value[i] = data
break
}
}
if (append) {
runningJobs.value.push(data)
}
}
}
}
});
_socket.addEventListener('close', () => {
connect()
});
}
onMounted(() => {
checkSession().then(user => {
imgCalls.value = user['img_calls']
//
httpGet("/api/sd/jobs?status=0").then(res => {
runningJobs.value = res.data
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
//
httpGet("/api/sd/jobs?status=1").then(res => {
finishedJobs.value = res.data
previewImgList.value = []
for (let index in finishedJobs.value) {
previewImgList.value.push(finishedJobs.value[index]["img_url"])
}
}).catch(e => {
ElMessage.error("获取任务失败:" + e.message)
})
// socket
connect();
}).catch(() => {
router.push('/login')
});
const clipboard = new Clipboard('.copy-prompt');
clipboard.on('success', () => {
ElMessage.success({message: "复制成功!", duration: 500});
})
clipboard.on('error', () => {
ElMessage.error('复制失败!');
})
})
//
const changeRate = (item) => {
params.value.rate = item.value
}
//
const promptRef = ref(null)
const generate = () => {
if (params.value.prompt === '') {
promptRef.value.focus()
return ElMessage.error("请输入绘画提示词!")
}
params.value.session_id = getSessionId()
httpPost("/api/sd/image", params.value).then(() => {
ElMessage.success("绘画任务推送成功,请耐心等待任务执行...")
imgCalls.value -= 1
}).catch(e => {
ElMessage.error("任务推送失败:" + e.message)
})
}
</script>
<style lang="stylus">
@import "@/assets/css/image-sd.styl"
</style>

View File

@ -18,8 +18,8 @@
<el-form-item label="开放注册服务" prop="enabled_register">
<el-switch v-model="system['enabled_register']"/>
</el-form-item>
<el-form-item label="短信验证服务" prop="enabled_msg_service">
<el-switch v-model="system['enabled_msg_service']"/>
<el-form-item label="短信验证服务" prop="enabled_msg">
<el-switch v-model="system['enabled_msg']"/>
</el-form-item>
<el-form-item label="开放AI绘画" prop="enabled_draw">
<el-switch v-model="system['enabled_draw']"/>