mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 03:03:43 +08:00
feat: stable diffusion page is ready
This commit is contained in:
@@ -157,7 +157,9 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
var tokenString string
|
||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" {
|
||||
} else if c.Request.URL.Path == "/api/chat/new" ||
|
||||
c.Request.URL.Path == "/api/mj/client" ||
|
||||
c.Request.URL.Path == "/api/sd/client" {
|
||||
tokenString = c.Query("token")
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
|
||||
@@ -101,13 +101,13 @@ type ModelAPIConfig struct {
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
Title string `json:"title"`
|
||||
AdminTitle string `json:"admin_title"`
|
||||
Models []string `json:"models"`
|
||||
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
|
||||
InitImgCalls int `json:"init_img_calls"`
|
||||
VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
|
||||
EnabledRegister bool `json:"enabled_register"`
|
||||
EnabledMsgService bool `json:"enabled_msg_service"`
|
||||
EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
|
||||
Title string `json:"title"`
|
||||
AdminTitle string `json:"admin_title"`
|
||||
Models []string `json:"models"`
|
||||
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
|
||||
InitImgCalls int `json:"init_img_calls"`
|
||||
VipMonthCalls int `json:"vip_month_calls"` // 会员每个赠送的调用次数
|
||||
EnabledRegister bool `json:"enabled_register"`
|
||||
EnabledMsg bool `json:"enabled_msg"` // 启用短信验证码服务
|
||||
EnabledDraw bool `json:"enabled_draw"` // 启动 AI 绘画功能
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type MjTask struct {
|
||||
}
|
||||
|
||||
type SdTask struct {
|
||||
Id int `json:"id"`
|
||||
Id int `json:"id"` // job 数据库ID
|
||||
SessionId string `json:"session_id"`
|
||||
Src TaskSrc `json:"src"`
|
||||
Type TaskType `json:"type"`
|
||||
@@ -52,18 +52,18 @@ type SdTask struct {
|
||||
|
||||
type SdTaskParams struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt"`
|
||||
Steps int `json:"steps"`
|
||||
Sampler string `json:"sampler"`
|
||||
FaceFix bool `json:"face_fix"`
|
||||
CfgScale float32 `json:"cfg_scale"`
|
||||
Seed int64 `json:"seed"`
|
||||
Prompt string `json:"prompt"` // 提示词
|
||||
NegativePrompt string `json:"negative_prompt"` // 反向提示词
|
||||
Steps int `json:"steps"` // 迭代步数,默认20
|
||||
Sampler string `json:"sampler"` // 采样器
|
||||
FaceFix bool `json:"face_fix"` // 面部修复
|
||||
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
|
||||
Seed int64 `json:"seed"` // 随机数种子
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
HdFix bool `json:"hd_fix"`
|
||||
HdRedrawRate float32 `json:"hd_redraw_rate"`
|
||||
HdScale int `json:"hd_scale"`
|
||||
HdScaleAlg string `json:"hd_scale_alg"`
|
||||
HdSampleNum int `json:"hd_sample_num"`
|
||||
HdFix bool `json:"hd_fix"` // 启用高清修复
|
||||
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
|
||||
HdScale int `json:"hd_scale"` // 放大倍数
|
||||
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
|
||||
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
@@ -17,33 +16,25 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
mjService *mj.Service
|
||||
uploaderManager *oss.UploaderManager
|
||||
lock sync.Mutex
|
||||
clients *types.LMap[string, *types.WsClient]
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
mjService *mj.Service
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(
|
||||
app *core.AppServer,
|
||||
client *redis.Client,
|
||||
db *gorm.DB,
|
||||
manager *oss.UploaderManager,
|
||||
mjService *mj.Service) *MidJourneyHandler {
|
||||
h := MidJourneyHandler{
|
||||
redis: client,
|
||||
db: db,
|
||||
uploaderManager: manager,
|
||||
lock: sync.Mutex{},
|
||||
mjService: mjService,
|
||||
clients: types.NewLMap[string, *types.WsClient](),
|
||||
redis: client,
|
||||
db: db,
|
||||
mjService: mjService,
|
||||
}
|
||||
h.App = app
|
||||
return &h
|
||||
@@ -59,9 +50,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
client := types.NewWsClient(ws)
|
||||
// 删除旧的连接
|
||||
h.clients.Delete(sessionId)
|
||||
h.clients.Put(sessionId, client)
|
||||
h.mjService.Clients.Put(sessionId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
}
|
||||
|
||||
@@ -156,7 +145,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.clients.Get(data.SessionId)
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
@@ -212,7 +201,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.clients.Get(data.SessionId)
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
@@ -283,7 +272,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.clients.Get(data.SessionId)
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
|
||||
@@ -1,316 +1,202 @@
|
||||
package handler
|
||||
|
||||
//
|
||||
//import (
|
||||
// "chatplus/core"
|
||||
// "chatplus/core/types"
|
||||
// "chatplus/service"
|
||||
// "chatplus/service/oss"
|
||||
// "chatplus/store/model"
|
||||
// "chatplus/store/vo"
|
||||
// "chatplus/utils"
|
||||
// "chatplus/utils/resp"
|
||||
// "encoding/base64"
|
||||
// "fmt"
|
||||
// "github.com/gin-gonic/gin"
|
||||
// "github.com/go-redis/redis/v8"
|
||||
// "github.com/gorilla/websocket"
|
||||
// "gorm.io/gorm"
|
||||
// "net/http"
|
||||
// "strings"
|
||||
// "sync"
|
||||
// "time"
|
||||
//)
|
||||
//
|
||||
//type SdJobHandler struct {
|
||||
// BaseHandler
|
||||
// redis *redis.Client
|
||||
// db *gorm.DB
|
||||
// mjService *service.MjService
|
||||
// uploaderManager *oss.UploaderManager
|
||||
// lock sync.Mutex
|
||||
// clients *types.LMap[string, *types.WsClient]
|
||||
//}
|
||||
//
|
||||
//func NewSdJobHandler(
|
||||
// app *core.AppServer,
|
||||
// client *redis.Client,
|
||||
// db *gorm.DB,
|
||||
// manager *oss.UploaderManager,
|
||||
// mjService *service.MjService) *MidJourneyHandler {
|
||||
// h := MidJourneyHandler{
|
||||
// redis: client,
|
||||
// db: db,
|
||||
// uploaderManager: manager,
|
||||
// lock: sync.Mutex{},
|
||||
// mjService: mjService,
|
||||
// clients: types.NewLMap[string, *types.WsClient](),
|
||||
// }
|
||||
// h.App = app
|
||||
// return &h
|
||||
//}
|
||||
//
|
||||
//// Client WebSocket 客户端,用于通知任务状态变更
|
||||
//func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
// ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
// if err != nil {
|
||||
// logger.Error(err)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// sessionId := c.Query("session_id")
|
||||
// client := types.NewWsClient(ws)
|
||||
// // 删除旧的连接
|
||||
// h.clients.Delete(sessionId)
|
||||
// h.clients.Put(sessionId, client)
|
||||
// logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
//}
|
||||
//
|
||||
//type sdNotifyData struct {
|
||||
// TaskId string
|
||||
// ImageName string
|
||||
// ImageData string
|
||||
// Progress int
|
||||
// Seed string
|
||||
// Success bool
|
||||
// Message string
|
||||
//}
|
||||
//
|
||||
//func (h *SdJobHandler) Notify(c *gin.Context) {
|
||||
// token := c.GetHeader("Authorization")
|
||||
// if token != h.App.Config.ExtConfig.Token {
|
||||
// resp.NotAuth(c)
|
||||
// return
|
||||
// }
|
||||
// var data sdNotifyData
|
||||
// if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
|
||||
// resp.ERROR(c, types.InvalidArgs)
|
||||
// return
|
||||
// }
|
||||
// logger.Debugf("收到 MidJourney 回调请求:%+v", data)
|
||||
//
|
||||
// h.lock.Lock()
|
||||
// defer h.lock.Unlock()
|
||||
//
|
||||
// err, finished := h.notifyHandler(c, data)
|
||||
// if err != nil {
|
||||
// resp.ERROR(c, err.Error())
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // 解除任务锁定
|
||||
// if finished && (data.Progress == 100) {
|
||||
// h.redis.Del(c, service.MjRunningJobKey)
|
||||
// }
|
||||
// resp.SUCCESS(c)
|
||||
//
|
||||
//}
|
||||
//
|
||||
//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
|
||||
// taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
// if err != nil { // 过期任务,丢弃
|
||||
// logger.Warn("任务已过期:", err)
|
||||
// return nil, true
|
||||
// }
|
||||
//
|
||||
// var task types.SdTask
|
||||
// err = utils.JsonDecode(taskString, &task)
|
||||
// if err != nil { // 非标准任务,丢弃
|
||||
// logger.Warn("任务解析失败:", err)
|
||||
// return nil, false
|
||||
// }
|
||||
//
|
||||
// var job model.SdJob
|
||||
// res := h.db.Where("id = ?", task.Id).First(&job)
|
||||
// if res.Error != nil {
|
||||
// logger.Warn("非法任务:", res.Error)
|
||||
// return nil, false
|
||||
// }
|
||||
// job.Params = utils.JsonEncode(task.Params)
|
||||
// job.ReferenceId = data.ImageData
|
||||
// job.Progress = data.Progress
|
||||
// job.Prompt = data.Prompt
|
||||
// job.Hash = data.Image.Hash
|
||||
//
|
||||
// // 任务完成,将最终的图片下载下来
|
||||
// if data.Progress == 100 {
|
||||
// imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
// if err != nil {
|
||||
// logger.Error("error with download img: ", err.Error())
|
||||
// return err, false
|
||||
// }
|
||||
// job.ImgURL = imgURL
|
||||
// } else {
|
||||
// // 临时图片直接保存,访问的时候使用代理进行转发
|
||||
// job.ImgURL = data.Image.URL
|
||||
// }
|
||||
// res = h.db.Updates(&job)
|
||||
// if res.Error != nil {
|
||||
// logger.Error("error with update job: ", res.Error)
|
||||
// return res.Error, false
|
||||
// }
|
||||
//
|
||||
// var jobVo vo.MidJourneyJob
|
||||
// err := utils.CopyObject(job, &jobVo)
|
||||
// if err == nil {
|
||||
// if data.Progress < 100 {
|
||||
// image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
|
||||
// if err == nil {
|
||||
// jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 推送任务到前端
|
||||
// client := h.clients.Get(task.SessionId)
|
||||
// if client != nil {
|
||||
// utils.ReplyChunkMessage(client, jobVo)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 更新用户剩余绘图次数
|
||||
// if data.Progress == 100 {
|
||||
// h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
// }
|
||||
//
|
||||
// return nil, true
|
||||
//}
|
||||
//
|
||||
//func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||
// user, err := utils.GetLoginUser(c, h.db)
|
||||
// if err != nil {
|
||||
// resp.NotAuth(c)
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// if user.ImgCalls <= 0 {
|
||||
// resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// return true
|
||||
//
|
||||
//}
|
||||
//
|
||||
//// Image 创建一个绘画任务
|
||||
//func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
// var data struct {
|
||||
// SessionId string `json:"session_id"`
|
||||
// Prompt string `json:"prompt"`
|
||||
// Rate string `json:"rate"`
|
||||
// Model string `json:"model"`
|
||||
// Chaos int `json:"chaos"`
|
||||
// Raw bool `json:"raw"`
|
||||
// Seed int64 `json:"seed"`
|
||||
// Stylize int `json:"stylize"`
|
||||
// Img string `json:"img"`
|
||||
// Weight float32 `json:"weight"`
|
||||
// }
|
||||
// if err := c.ShouldBindJSON(&data); err != nil {
|
||||
// resp.ERROR(c, types.InvalidArgs)
|
||||
// return
|
||||
// }
|
||||
// if !h.checkLimits(c) {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// var prompt = data.Prompt
|
||||
// if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||
// prompt += " --ar " + data.Rate
|
||||
// }
|
||||
// if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
||||
// prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
||||
// }
|
||||
// if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
||||
// prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
||||
// }
|
||||
// if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
||||
// prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
||||
// }
|
||||
// if data.Img != "" {
|
||||
// prompt = fmt.Sprintf("%s %s", data.Img, prompt)
|
||||
// if data.Weight > 0 {
|
||||
// prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
||||
// }
|
||||
// }
|
||||
// if data.Raw {
|
||||
// prompt += " --style raw"
|
||||
// }
|
||||
// if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
||||
// prompt += data.Model
|
||||
// }
|
||||
//
|
||||
// idValue, _ := c.Get(types.LoginUserID)
|
||||
// userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
// job := model.MidJourneyJob{
|
||||
// Type: service.Image.String(),
|
||||
// UserId: userId,
|
||||
// Progress: 0,
|
||||
// Prompt: prompt,
|
||||
// CreatedAt: time.Now(),
|
||||
// }
|
||||
// if res := h.db.Create(&job); res.Error != nil {
|
||||
// resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// h.mjService.PushTask(service.MjTask{
|
||||
// Id: int(job.Id),
|
||||
// SessionId: data.SessionId,
|
||||
// Src: service.TaskSrcImg,
|
||||
// Type: service.Image,
|
||||
// Prompt: prompt,
|
||||
// UserId: userId,
|
||||
// })
|
||||
//
|
||||
// var jobVo vo.MidJourneyJob
|
||||
// err := utils.CopyObject(job, &jobVo)
|
||||
// if err == nil {
|
||||
// // 推送任务到前端
|
||||
// client := h.clients.Get(data.SessionId)
|
||||
// if client != nil {
|
||||
// utils.ReplyChunkMessage(client, jobVo)
|
||||
// }
|
||||
// }
|
||||
// resp.SUCCESS(c)
|
||||
//}
|
||||
//
|
||||
//// JobList 获取 MJ 任务列表
|
||||
//func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||
// status := h.GetInt(c, "status", 0)
|
||||
// var items []model.MidJourneyJob
|
||||
// var res *gorm.DB
|
||||
// userId, _ := c.Get(types.LoginUserID)
|
||||
// if status == 1 {
|
||||
// res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
|
||||
// } else {
|
||||
// res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
|
||||
// }
|
||||
// if res.Error != nil {
|
||||
// resp.ERROR(c, types.NoData)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// var jobs = make([]vo.MidJourneyJob, 0)
|
||||
// for _, item := range items {
|
||||
// var job vo.MidJourneyJob
|
||||
// err := utils.CopyObject(item, &job)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// if item.Progress < 100 {
|
||||
// // 30 分钟还没完成的任务直接删除
|
||||
// if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
|
||||
// h.db.Delete(&item)
|
||||
// continue
|
||||
// }
|
||||
// if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
||||
// image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
||||
// if err == nil {
|
||||
// job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// jobs = append(jobs, job)
|
||||
// }
|
||||
// resp.SUCCESS(c, jobs)
|
||||
//}
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/sd"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
service *sd.Service
|
||||
}
|
||||
|
||||
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler {
|
||||
h := SdJobHandler{
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
service: service,
|
||||
}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
client := types.NewWsClient(ws)
|
||||
// 删除旧的连接
|
||||
h.service.Clients.Put(sessionId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||
user, err := utils.GetLoginUser(c, h.db)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return false
|
||||
}
|
||||
|
||||
if user.ImgCalls <= 0 {
|
||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
if !h.App.Config.SdConfig.Enabled {
|
||||
resp.ERROR(c, "Stable Diffusion service is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.checkLimits(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
SessionId string `json:"session_id"`
|
||||
types.SdTaskParams
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Width <= 0 {
|
||||
data.Width = 512
|
||||
}
|
||||
if data.Height <= 0 {
|
||||
data.Height = 512
|
||||
}
|
||||
if data.CfgScale <= 0 {
|
||||
data.CfgScale = 7
|
||||
}
|
||||
if data.Seed == 0 {
|
||||
data.Seed = -1
|
||||
}
|
||||
if data.Steps <= 0 {
|
||||
data.Steps = 20
|
||||
}
|
||||
if data.Sampler == "" {
|
||||
data.Sampler = "Euler a"
|
||||
}
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
params := types.SdTaskParams{
|
||||
TaskId: fmt.Sprintf("task(%s)", utils.RandString(15)),
|
||||
Prompt: data.Prompt,
|
||||
NegativePrompt: data.NegativePrompt,
|
||||
Steps: data.Steps,
|
||||
Sampler: data.Sampler,
|
||||
FaceFix: data.FaceFix,
|
||||
CfgScale: data.CfgScale,
|
||||
Seed: data.Seed,
|
||||
Height: data.Height,
|
||||
Width: data.Width,
|
||||
HdFix: data.HdFix,
|
||||
HdRedrawRate: data.HdRedrawRate,
|
||||
HdScale: data.HdScale,
|
||||
HdScaleAlg: data.HdScaleAlg,
|
||||
HdSteps: data.HdSteps,
|
||||
}
|
||||
job := model.SdJob{
|
||||
UserId: userId,
|
||||
Type: types.TaskImage.String(),
|
||||
TaskId: params.TaskId,
|
||||
Params: utils.JsonEncode(params),
|
||||
Prompt: data.Prompt,
|
||||
Progress: 0,
|
||||
Started: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
res := h.db.Create(&job)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "error with save job: "+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.service.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: types.TaskSrcImg,
|
||||
Type: types.TaskImage,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
})
|
||||
var jobVo vo.SdJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.service.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// JobList 获取 MJ 任务列表
|
||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||
status := h.GetInt(c, "status", 0)
|
||||
var items []model.SdJob
|
||||
var res *gorm.DB
|
||||
userId, _ := c.Get(types.LoginUserID)
|
||||
if status == 1 {
|
||||
res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
|
||||
} else {
|
||||
res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
|
||||
}
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, types.NoData)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs = make([]vo.SdJob, 0)
|
||||
for _, item := range items {
|
||||
var job vo.SdJob
|
||||
err := utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if item.Progress < 100 {
|
||||
// 30 分钟还没完成的任务直接删除
|
||||
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
|
||||
h.db.Delete(&item)
|
||||
continue
|
||||
}
|
||||
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
||||
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
resp.SUCCESS(c, jobs)
|
||||
}
|
||||
|
||||
@@ -66,5 +66,5 @@ type statusVo struct {
|
||||
|
||||
// Status check if the message service is enabled
|
||||
func (h *SmsHandler) Status(c *gin.Context) {
|
||||
resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsgService, EnabledRegister: h.App.SysConfig.EnabledRegister})
|
||||
resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister})
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
|
||||
// 检查验证码
|
||||
key := CodeStorePrefix + data.Mobile
|
||||
if h.App.SysConfig.EnabledMsgService {
|
||||
if h.App.SysConfig.EnabledMsg {
|
||||
var code int
|
||||
err := h.leveldb.Get(key, &code)
|
||||
if err != nil || code != data.Code {
|
||||
@@ -113,7 +113,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.App.SysConfig.EnabledMsgService {
|
||||
if h.App.SysConfig.EnabledMsg {
|
||||
_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
|
||||
}
|
||||
resp.SUCCESS(c, user)
|
||||
|
||||
15
api/main.go
15
api/main.go
@@ -10,6 +10,7 @@ import (
|
||||
"chatplus/service/fun"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/service/sd"
|
||||
"chatplus/service/wx"
|
||||
"chatplus/store"
|
||||
"context"
|
||||
@@ -121,6 +122,7 @@ func main() {
|
||||
fx.Provide(handler.NewCaptchaHandler),
|
||||
fx.Provide(handler.NewMidJourneyHandler),
|
||||
fx.Provide(handler.NewChatModelHandler),
|
||||
fx.Provide(handler.NewSdJobHandler),
|
||||
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
@@ -167,6 +169,13 @@ func main() {
|
||||
}
|
||||
}),
|
||||
|
||||
// Stable Diffusion 机器人
|
||||
fx.Provide(sd.NewService),
|
||||
fx.Invoke(func(service *sd.Service) {
|
||||
go func() {
|
||||
service.Run()
|
||||
}()
|
||||
}),
|
||||
// 注册路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||
group := s.Engine.Group("/api/role/")
|
||||
@@ -220,6 +229,12 @@ func main() {
|
||||
group.GET("jobs", h.JobList)
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
|
||||
group := s.Engine.Group("/api/sd")
|
||||
group.POST("image", h.Image)
|
||||
group.GET("jobs", h.JobList)
|
||||
group.Any("client", h.Client)
|
||||
}),
|
||||
|
||||
// 管理后台控制器
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
const RunningJobKey = "MidJourney_Running_Job"
|
||||
|
||||
type Service struct {
|
||||
client *Client
|
||||
client *Client // MJ 客户端
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
@@ -128,7 +128,7 @@ func (s *Service) Notify(data CBReq) {
|
||||
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
@@ -169,7 +169,7 @@ func (s *Service) Notify(data CBReq) {
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
// download image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
|
||||
@@ -63,8 +63,14 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
||||
return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutImg(imageURL string) (string, error) {
|
||||
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
|
||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
|
||||
@@ -41,14 +41,18 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
|
||||
return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutImg(imageURL string) (string, error) {
|
||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
filename := filepath.Base(imageURL)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||
}
|
||||
|
||||
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
|
||||
if useProxy {
|
||||
err = utils.DownloadFile(imageURL, filePath, s.proxyURL)
|
||||
} else {
|
||||
err = utils.DownloadFile(imageURL, filePath, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
|
||||
@@ -31,8 +31,14 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutImg(imageURL string) (string, error) {
|
||||
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
|
||||
func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
|
||||
@@ -72,8 +72,14 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutImg(imageURL string) (string, error) {
|
||||
imageData, err := utils.DownloadImage(imageURL, s.proxyURL)
|
||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
|
||||
var imageData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
imageData, err = utils.DownloadImage(imageURL, s.proxyURL)
|
||||
} else {
|
||||
imageData, err = utils.DownloadImage(imageURL, "")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,6 @@ import "github.com/gin-gonic/gin"
|
||||
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (string, error)
|
||||
PutImg(imageURL string) (string, error)
|
||||
PutImg(imageURL string, useProxy bool) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
httpClient *req.Client
|
||||
config *types.StableDiffusionConfig
|
||||
}
|
||||
|
||||
func NewSdClient(config *types.AppConfig) *Client {
|
||||
return &Client{
|
||||
config: &config.SdConfig,
|
||||
httpClient: req.C(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Txt2Img(params types.SdTaskParams) error {
|
||||
var data []interface{}
|
||||
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[ParamKeys["task_id"]] = params.TaskId
|
||||
data[ParamKeys["prompt"]] = params.Prompt
|
||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
||||
data[ParamKeys["steps"]] = params.Steps
|
||||
data[ParamKeys["sampler"]] = params.Sampler
|
||||
data[ParamKeys["face_fix"]] = params.FaceFix
|
||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
||||
data[ParamKeys["seed"]] = params.Seed
|
||||
data[ParamKeys["height"]] = params.Height
|
||||
data[ParamKeys["width"]] = params.Width
|
||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
||||
data[ParamKeys["hd_sample_num"]] = params.HdSampleNum
|
||||
task := TaskInfo{
|
||||
TaskId: params.TaskId,
|
||||
Data: data,
|
||||
EventData: nil,
|
||||
FnIndex: 494,
|
||||
SessionHash: "ycaxgzm9ah",
|
||||
}
|
||||
|
||||
go func() {
|
||||
c.runTask(task, c.httpClient)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||
body := map[string]any{
|
||||
"data": taskInfo.Data,
|
||||
"event_data": taskInfo.EventData,
|
||||
"fn_index": taskInfo.FnIndex,
|
||||
"session_hash": taskInfo.SessionHash,
|
||||
}
|
||||
|
||||
var result = make(chan CBReq)
|
||||
go func() {
|
||||
var res struct {
|
||||
Data []interface{} `json:"data"`
|
||||
IsGenerating bool `json:"is_generating"`
|
||||
Duration float64 `json:"duration"`
|
||||
AverageDuration float64 `json:"average_duration"`
|
||||
}
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId}
|
||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
|
||||
if err != nil {
|
||||
cbReq.Message = "error with send request: " + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
cbReq.Message = "error http status code: " + string(bytes)
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var images []struct {
|
||||
Name string `json:"name"`
|
||||
Data interface{} `json:"data"`
|
||||
IsFile bool `json:"is_file"`
|
||||
}
|
||||
err = utils.ForceCovert(res.Data[0], &images)
|
||||
if err != nil {
|
||||
cbReq.Message = "error with decode image:" + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var info map[string]any
|
||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
||||
if err != nil {
|
||||
cbReq.Message = err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
//for k, v := range info {
|
||||
// fmt.Println(k, " => ", v)
|
||||
//}
|
||||
cbReq.ImageName = images[0].Name
|
||||
cbReq.Seed = utils.InterfaceToString(info["seed"])
|
||||
cbReq.Success = true
|
||||
cbReq.Progress = 100
|
||||
result <- cbReq
|
||||
close(result)
|
||||
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case value := <-result:
|
||||
if value.Success {
|
||||
logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
|
||||
}
|
||||
return
|
||||
default:
|
||||
var progressReq = map[string]any{
|
||||
"id_task": taskInfo.TaskId,
|
||||
"id_live_preview": 1,
|
||||
}
|
||||
|
||||
var progressRes struct {
|
||||
Active bool `json:"active"`
|
||||
Queued bool `json:"queued"`
|
||||
Completed bool `json:"completed"`
|
||||
Progress float64 `json:"progress"`
|
||||
Eta float64 `json:"eta"`
|
||||
LivePreview string `json:"live_preview"`
|
||||
IDLivePreview int `json:"id_live_preview"`
|
||||
TextInfo interface{} `json:"textinfo"`
|
||||
}
|
||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true}
|
||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
logger.Error(string(bytes))
|
||||
return
|
||||
}
|
||||
|
||||
cbReq.ImageData = progressRes.LivePreview
|
||||
cbReq.Progress = int(progressRes.Progress * 100)
|
||||
fmt.Println("Progress: ", progressRes.Progress)
|
||||
fmt.Println("Image: ", progressRes.LivePreview)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
const RunningJobKey = "StableDiffusion_Running_Job"
|
||||
|
||||
type Service struct {
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
Client *Client
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
|
||||
return &Service{
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
Client: client,
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting StableDiffusion job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task types.SdTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
err = s.Client.Txt2Img(task.Params)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
300
api/service/sd/service.go
Normal file
300
api/service/sd/service.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
const RunningJobKey = "StableDiffusion_Running_Job"
|
||||
|
||||
type Service struct {
|
||||
httpClient *req.Client
|
||||
config *types.StableDiffusionConfig
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
Clients *types.LMap[string, *types.WsClient] // SD 绘画页面 websocket 连接池
|
||||
}
|
||||
|
||||
func NewService(config *types.AppConfig, redisCli *redis.Client, db *gorm.DB, manager *oss.UploaderManager) *Service {
|
||||
return &Service{
|
||||
config: &config.SdConfig,
|
||||
httpClient: req.C(),
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
uploadManager: manager,
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting StableDiffusion job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task types.SdTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
err = s.Txt2Img(task)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.SdJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
// PushTask 推送任务到队列
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
// Txt2Img 文生图 API
|
||||
func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
var data []interface{}
|
||||
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params := task.Params
|
||||
data[ParamKeys["task_id"]] = params.TaskId
|
||||
data[ParamKeys["prompt"]] = params.Prompt
|
||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
||||
data[ParamKeys["steps"]] = params.Steps
|
||||
data[ParamKeys["sampler"]] = params.Sampler
|
||||
data[ParamKeys["face_fix"]] = params.FaceFix
|
||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
||||
data[ParamKeys["seed"]] = params.Seed
|
||||
data[ParamKeys["height"]] = params.Height
|
||||
data[ParamKeys["width"]] = params.Width
|
||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
||||
data[ParamKeys["hd_sample_num"]] = params.HdSteps
|
||||
|
||||
go func() {
|
||||
s.runTask(TaskInfo{
|
||||
SessionId: task.SessionId,
|
||||
JobId: task.Id,
|
||||
TaskId: params.TaskId,
|
||||
Data: data,
|
||||
EventData: nil,
|
||||
FnIndex: 405,
|
||||
SessionHash: "ycaxgzm9ah",
|
||||
}, s.httpClient)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||
body := map[string]any{
|
||||
"data": taskInfo.Data,
|
||||
"event_data": taskInfo.EventData,
|
||||
"fn_index": taskInfo.FnIndex,
|
||||
"session_hash": taskInfo.SessionHash,
|
||||
}
|
||||
logger.Debug(utils.JsonEncode(body))
|
||||
var result = make(chan CBReq)
|
||||
go func() {
|
||||
var res struct {
|
||||
Data []interface{} `json:"data"`
|
||||
IsGenerating bool `json:"is_generating"`
|
||||
Duration float64 `json:"duration"`
|
||||
AverageDuration float64 `json:"average_duration"`
|
||||
}
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict")
|
||||
if err != nil {
|
||||
cbReq.Message = "error with send request: " + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
cbReq.Message = "error http status code: " + string(bytes)
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var images []struct {
|
||||
Name string `json:"name"`
|
||||
Data interface{} `json:"data"`
|
||||
IsFile bool `json:"is_file"`
|
||||
}
|
||||
err = utils.ForceCovert(res.Data[0], &images)
|
||||
if err != nil {
|
||||
cbReq.Message = "error with decode image:" + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var info map[string]any
|
||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
||||
if err != nil {
|
||||
cbReq.Message = err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
//for k, v := range info {
|
||||
// fmt.Println(k, " => ", v)
|
||||
//}
|
||||
cbReq.ImageName = images[0].Name
|
||||
seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64)
|
||||
cbReq.Seed = seed
|
||||
cbReq.Success = true
|
||||
cbReq.Progress = 100
|
||||
result <- cbReq
|
||||
close(result)
|
||||
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case value := <-result:
|
||||
s.callback(value)
|
||||
return
|
||||
default:
|
||||
var progressReq = map[string]any{
|
||||
"id_task": taskInfo.TaskId,
|
||||
"id_live_preview": 1,
|
||||
}
|
||||
|
||||
var progressRes struct {
|
||||
Active bool `json:"active"`
|
||||
Queued bool `json:"queued"`
|
||||
Completed bool `json:"completed"`
|
||||
Progress float64 `json:"progress"`
|
||||
Eta float64 `json:"eta"`
|
||||
LivePreview string `json:"live_preview"`
|
||||
IDLivePreview int `json:"id_live_preview"`
|
||||
TextInfo interface{} `json:"textinfo"`
|
||||
}
|
||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress")
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId}
|
||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
logger.Error(string(bytes))
|
||||
return
|
||||
}
|
||||
|
||||
cbReq.ImageData = progressRes.LivePreview
|
||||
cbReq.Progress = int(progressRes.Progress * 100)
|
||||
s.callback(cbReq)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) callback(data CBReq) {
|
||||
client := s.Clients.Get(data.SessionId)
|
||||
if data.Success { // 任务成功
|
||||
var job model.SdJob
|
||||
res := s.db.Where("id = ?", data.JobId).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
// 更新任务进度
|
||||
job.Progress = data.Progress
|
||||
// 更新任务 seed
|
||||
var params types.SdTaskParams
|
||||
err := utils.JsonDecode(job.Params, ¶ms)
|
||||
if err != nil {
|
||||
logger.Error("任务解析失败:", err)
|
||||
return
|
||||
}
|
||||
|
||||
params.Seed = data.Seed
|
||||
if data.ImageName != "" { // 下载图片
|
||||
imageURL := fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName)
|
||||
imageURL, err := s.uploadManager.GetUploadHandler().PutImg(imageURL, false)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imageURL
|
||||
}
|
||||
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
var jobVo vo.SdJob
|
||||
err = utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
logger.Error("error with copy object: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Progress < 100 {
|
||||
logger.Infof(data.ImageData)
|
||||
jobVo.ImgURL = data.ImageData
|
||||
}
|
||||
|
||||
// 推送任务到前端
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
} else { // 任务失败
|
||||
logger.Error("任务执行失败:", data.Message)
|
||||
// 删除任务
|
||||
s.db.Delete(&model.SdJob{Id: uint(data.JobId)})
|
||||
// 推送消息到前端
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, vo.SdJob{
|
||||
Id: uint(data.JobId),
|
||||
Progress: -1,
|
||||
Prompt: fmt.Sprintf("任务[%s]执行失败,已删除!", data.TaskId),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,19 +5,23 @@ import logger2 "chatplus/logger"
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type TaskInfo struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Data interface{} `json:"data"`
|
||||
EventData interface{} `json:"event_data"`
|
||||
FnIndex int `json:"fn_index"`
|
||||
SessionHash string `json:"session_hash"`
|
||||
SessionId string
|
||||
JobId int
|
||||
TaskId string
|
||||
Data []interface{}
|
||||
EventData interface{}
|
||||
FnIndex int
|
||||
SessionHash string
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
SessionId string
|
||||
JobId int
|
||||
TaskId string
|
||||
ImageName string
|
||||
ImageData string
|
||||
Progress int
|
||||
Seed string
|
||||
Seed int64
|
||||
Success bool
|
||||
Message string
|
||||
}
|
||||
@@ -41,164 +45,170 @@ var ParamKeys = map[string]int{
|
||||
}
|
||||
|
||||
const Text2ImgParamTemplate = `[
|
||||
"",
|
||||
"",
|
||||
"task(p1lk3n41saygmr8)",
|
||||
"a tiger sit on the window",
|
||||
"",
|
||||
[],
|
||||
30,
|
||||
"DPM++ SDE Karras",
|
||||
20,
|
||||
"Euler a",
|
||||
false,
|
||||
false,
|
||||
1,
|
||||
1,
|
||||
7.5,
|
||||
7,
|
||||
-1,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
512,
|
||||
512,
|
||||
true,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
0.7,
|
||||
2,
|
||||
"Latent",
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
"Use same sampler",
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
[],
|
||||
"None",
|
||||
false,
|
||||
"MultiDiffusion",
|
||||
false,
|
||||
10,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
false,
|
||||
true,
|
||||
1024,
|
||||
1024,
|
||||
96,
|
||||
96,
|
||||
48,
|
||||
4,
|
||||
1,
|
||||
"None",
|
||||
2,
|
||||
false,
|
||||
10,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
1536,
|
||||
96,
|
||||
false,
|
||||
false,
|
||||
"LoRA",
|
||||
"None",
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
3072,
|
||||
192,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
"LoRA",
|
||||
"None",
|
||||
1,
|
||||
1,
|
||||
"LoRA",
|
||||
"None",
|
||||
1,
|
||||
1,
|
||||
"LoRA",
|
||||
"None",
|
||||
1,
|
||||
1,
|
||||
"LoRA",
|
||||
"None",
|
||||
1,
|
||||
1,
|
||||
null,
|
||||
"Refresh models",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
"",
|
||||
0.5,
|
||||
true,
|
||||
false,
|
||||
"",
|
||||
"Lerp",
|
||||
false,
|
||||
"🔄",
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
"positive",
|
||||
@@ -209,26 +219,26 @@ false,
|
||||
"",
|
||||
"Seed",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
50
|
||||
50,
|
||||
[],
|
||||
"",
|
||||
"",
|
||||
""
|
||||
]`
|
||||
|
||||
Reference in New Issue
Block a user