feat: add implements for stable diffusion service

This commit is contained in:
RockYang 2023-09-26 18:16:51 +08:00
parent c1143d7a6d
commit d51a724ade
8 changed files with 569 additions and 62 deletions

70
api/core/types/task.go Normal file
View File

@ -0,0 +1,70 @@
package types
// TaskType 任务类别
type TaskType string
func (t TaskType) String() string {
return string(t)
}
const (
TaskImage = TaskType("image")
TaskUpscale = TaskType("upscale")
TaskVariation = TaskType("variation")
TaskTxt2Img = TaskType("text2img")
)
// TaskSrc 任务来源
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
TaskSrcImg = TaskSrc("img") // 专业绘画页面
)
// MjTask MidJourney 任务
type MjTask struct {
Id int `json:"id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int32 `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"`
}
// SdParams stable diffusion 绘画参数
type SdParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Steps int `json:"steps"`
Sampler string `json:"sampler"`
FaceFix bool `json:"face_fix"`
CfgScale float32 `json:"cfg_scale"`
Seed int64 `json:"seed"`
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"`
HdRedrawRate float32 `json:"hd_redraw_rate"`
HdScale int `json:"hd_scale"`
HdScaleAlg string `json:"hd_scale_alg"`
HdSampleNum int `json:"hd_sample_num"`
}
type SdTask struct {
Id int `json:"id"`
SessionId string `json:"session_id"`
Src types.TaskSrc `json:"src"`
Type types.TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params types.SdParams `json:"params"`
RetryCount int `json:"retry_count"`
}

View File

@ -66,7 +66,7 @@ func NewMidJourneyHandler(
return &h
}
type notifyData struct {
type mjNotifyData struct {
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
@ -98,7 +98,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.NotAuth(c)
return
}
var data notifyData
var data mjNotifyData
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
@ -122,14 +122,14 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
}
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return nil, true
}
var task service.MjTask
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
@ -143,7 +143,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
return nil, false
}
if task.Src == service.TaskSrcImg { // 绘画任务
if task.Src == types.TaskSrcImg { // 绘画任务
var job model.MidJourneyJob
res := h.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil {
@ -191,7 +191,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
}
}
} else if task.Src == service.TaskSrcChat { // 聊天任务
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := h.App.MjTaskClients.Get(task.SessionId)
if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
@ -342,7 +342,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.MidJourneyJob{
Type: service.Image.String(),
Type: types.TaskImage.String(),
UserId: userId,
Progress: 0,
Prompt: prompt,
@ -353,11 +353,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: int(job.Id),
SessionId: data.SessionId,
Src: service.TaskSrcImg,
Type: service.Image,
Src: types.TaskSrcImg,
Type: types.TaskImage,
Prompt: prompt,
UserId: userId,
})
@ -401,10 +401,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: service.Upscale.String(),
Type: types.TaskUpscale.String(),
UserId: userId,
Hash: data.MessageHash,
Progress: 0,
@ -428,11 +428,11 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
}
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Upscale,
Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,
@ -470,10 +470,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: service.Variation.String(),
Type: types.TaskVariation.String(),
UserId: userId,
ImgURL: "",
Hash: data.MessageHash,
@ -498,11 +498,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
}
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Variation,
Type: types.TaskVariation,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,

315
api/handler/sd_handler.go Normal file
View File

@ -0,0 +1,315 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"encoding/base64"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
"sync"
"time"
)
type SdJobHandler struct {
BaseHandler
redis *redis.Client
db *gorm.DB
mjService *service.MjService
uploaderManager *oss.UploaderManager
lock sync.Mutex
clients *types.LMap[string, *types.WsClient]
}
func NewSdJobHandler(
app *core.AppServer,
client *redis.Client,
db *gorm.DB,
manager *oss.UploaderManager,
mjService *service.MjService) *MidJourneyHandler {
h := MidJourneyHandler{
redis: client,
db: db,
uploaderManager: manager,
lock: sync.Mutex{},
mjService: mjService,
clients: types.NewLMap[string, *types.WsClient](),
}
h.App = app
return &h
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
client := types.NewWsClient(ws)
// 删除旧的连接
h.clients.Delete(sessionId)
h.clients.Put(sessionId, client)
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
}
type sdNotifyData struct {
TaskId string
ImageName string
ImageData string
Progress int
Seed string
Success bool
Message string
}
func (h *SdJobHandler) Notify(c *gin.Context) {
token := c.GetHeader("Authorization")
if token != h.App.Config.ExtConfig.Token {
resp.NotAuth(c)
return
}
var data sdNotifyData
if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
logger.Debugf("收到 MidJourney 回调请求:%+v", data)
h.lock.Lock()
defer h.lock.Unlock()
err, finished := h.notifyHandler(c, data)
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 解除任务锁定
if finished && (data.Progress == 100) {
h.redis.Del(c, service.MjRunningJobKey)
}
resp.SUCCESS(c)
}
func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return nil, true
}
var task types.SdTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
return nil, false
}
var job model.SdJob
res := h.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil {
logger.Warn("非法任务:", res.Error)
return nil, false
}
job.Params = utils.JsonEncode(task.Params)
job.ReferenceId = data.ImageData
job.Progress = data.Progress
job.Prompt = data.Prompt
job.Hash = data.Image.Hash
// 任务完成,将最终的图片下载下来
if data.Progress == 100 {
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
if err != nil {
logger.Error("error with download img: ", err.Error())
return err, false
}
job.ImgURL = imgURL
} else {
// 临时图片直接保存,访问的时候使用代理进行转发
job.ImgURL = data.Image.URL
}
res = h.db.Updates(&job)
if res.Error != nil {
logger.Error("error with update job: ", res.Error)
return res.Error, false
}
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
if data.Progress < 100 {
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
if err == nil {
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
// 推送任务到前端
client := h.clients.Get(task.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
// 更新用户剩余绘图次数
if data.Progress == 100 {
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
return nil, true
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
user, err := utils.GetLoginUser(c, h.db)
if err != nil {
resp.NotAuth(c)
return false
}
if user.ImgCalls <= 0 {
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
var data struct {
SessionId string `json:"session_id"`
Prompt string `json:"prompt"`
Rate string `json:"rate"`
Model string `json:"model"`
Chaos int `json:"chaos"`
Raw bool `json:"raw"`
Seed int64 `json:"seed"`
Stylize int `json:"stylize"`
Img string `json:"img"`
Weight float32 `json:"weight"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if !h.checkLimits(c) {
return
}
var prompt = data.Prompt
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
prompt += " --ar " + data.Rate
}
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
prompt += fmt.Sprintf(" --seed %d", data.Seed)
}
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
prompt += fmt.Sprintf(" --s %d", data.Stylize)
}
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Img != "" {
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
}
}
if data.Raw {
prompt += " --style raw"
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += data.Model
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.MidJourneyJob{
Type: service.Image.String(),
UserId: userId,
Progress: 0,
Prompt: prompt,
CreatedAt: time.Now(),
}
if res := h.db.Create(&job); res.Error != nil {
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
return
}
h.mjService.PushTask(service.MjTask{
Id: int(job.Id),
SessionId: data.SessionId,
Src: service.TaskSrcImg,
Type: service.Image,
Prompt: prompt,
UserId: userId,
})
var jobVo vo.MidJourneyJob
err := utils.CopyObject(job, &jobVo)
if err == nil {
// 推送任务到前端
client := h.clients.Get(data.SessionId)
if client != nil {
utils.ReplyChunkMessage(client, jobVo)
}
}
resp.SUCCESS(c)
}
// JobList 获取 MJ 任务列表
func (h *SdJobHandler) JobList(c *gin.Context) {
status := h.GetInt(c, "status", 0)
var items []model.MidJourneyJob
var res *gorm.DB
userId, _ := c.Get(types.LoginUserID)
if status == 1 {
res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
} else {
res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
}
if res.Error != nil {
resp.ERROR(c, types.NoData)
return
}
var jobs = make([]vo.MidJourneyJob, 0)
for _, item := range items {
var job vo.MidJourneyJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
if item.Progress < 100 {
// 30 分钟还没完成的任务直接删除
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
h.db.Delete(&item)
continue
}
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
}
jobs = append(jobs, job)
}
resp.SUCCESS(c, jobs)
}

View File

@ -21,41 +21,6 @@ var logger = logger2.GetLogger()
const MjRunningJobKey = "MidJourney_Running_Job"
type TaskType string
func (t TaskType) String() string {
return string(t)
}
const (
Image = TaskType("image")
Upscale = TaskType("upscale")
Variation = TaskType("variation")
)
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat")
TaskSrcImg = TaskSrc("img")
)
type MjTask struct {
Id int `json:"id"`
SessionId string `json:"session_id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int32 `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"`
}
type MjService struct {
config types.ChatPlusExtConfig
client *req.Client
@ -78,11 +43,11 @@ func (s *MjService) Run() {
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
if err == nil {
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task MjTask
var task types.MjTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
@ -90,17 +55,17 @@ func (s *MjService) Run() {
}
logger.Infof("Consuming Task: %+v", task)
switch task.Type {
case Image:
case types.TaskImage:
err = s.image(task.Prompt)
break
case Upscale:
case types.TaskUpscale:
err = s.upscale(MjUpscaleReq{
Index: task.Index,
MessageId: task.MessageId,
MessageHash: task.MessageHash,
})
break
case Variation:
case types.TaskVariation:
err = s.variation(MjVariationReq{
Index: task.Index,
MessageId: task.MessageId,
@ -124,7 +89,7 @@ func (s *MjService) Run() {
}
}
func (s *MjService) PushTask(task MjTask) {
func (s *MjService) PushTask(task types.MjTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}

95
api/service/sd_service.go Normal file
View File

@ -0,0 +1,95 @@
package service
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"time"
)
// SD 绘画服务
const SdRunningJobKey = "StableDiffusion_Running_Job"
type SdService struct {
config types.ChatPlusExtConfig
client *req.Client
taskQueue *store.RedisQueue
redis *redis.Client
db *gorm.DB
}
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
return &SdService{
config: appConfig.ExtConfig,
redis: client,
db: db,
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
client: req.C().SetTimeout(30 * time.Second)}
}
func (s *SdService) Run() {
logger.Info("Starting StableDiffusion job consumer.")
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
if err == nil { // 队列串行执行
time.Sleep(time.Second * 3)
continue
}
var task types.SdTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Infof("Consuming Task: %+v", task)
err = s.txt2img(task.Params)
if err != nil {
logger.Error("绘画任务执行失败:", err)
if task.RetryCount <= 5 {
s.taskQueue.RPush(task)
}
task.RetryCount += 1
time.Sleep(time.Second * 3)
continue
}
// 更新任务的执行状态
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
// 锁定任务执行通道直到任务超时5分钟
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
}
}
func (s *SdService) PushTask(task types.SdTask) {
logger.Infof("add a new MidJourney Task: %+v", task)
s.taskQueue.RPush(task)
}
func (s *SdService) txt2img(params types.SdParams) error {
logger.Infof("SD 绘画参数:%+v", params)
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("Authorization", s.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(params).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}

20
api/store/model/sd_job.go Normal file
View File

@ -0,0 +1,20 @@
package model
import "time"
type SdJob struct {
Id uint `gorm:"primarykey;column:id"`
Type string
UserId int
TaskId string
ImgURL string
Progress int
Prompt string
Params string
Started bool
CreatedAt time.Time
}
func (SdJob) TableName() string {
return "chatgpt_sd_jobs"
}

19
api/store/vo/sd_job.go Normal file
View File

@ -0,0 +1,19 @@
package vo
import (
"chatplus/core/types"
"time"
)
type SdJob struct {
Id uint `json:"id"`
Type string `json:"type"`
UserId int `json:"user_id"`
TaskId string `json:"task_id"`
ImgURL string `json:"img_url"`
Params types.SdParams `json:"params"`
Progress int `json:"progress"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
Started bool `json:"started"`
}

View File

@ -1,2 +1,25 @@
ALTER TABLE `chatgpt_mj_jobs` ADD `started` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始' AFTER `progress`;
UPDATE `chatgpt_mj_jobs` SET started = 1
UPDATE `chatgpt_mj_jobs` SET started = 1
-- 创建 SD 绘图任务表
CREATE TABLE `chatgpt_sd_jobs` (
`id` int NOT NULL,
`user_id` int NOT NULL COMMENT '用户 ID',
`type` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT 'txt2img' COMMENT '任务类别',
`task_id` char(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '任务 ID',
`prompt` varchar(2000) NOT NULL COMMENT '会话提示词',
`img_url` varchar(255) DEFAULT NULL COMMENT '图片URL',
`params` text CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci COMMENT '绘画参数json',
`progress` smallint DEFAULT '0' COMMENT '任务进度',
`started` tinyint(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始',
`created_at` datetime NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='StableDiffusion 任务表';
--
-- 表的索引 `chatgpt_sd_jobs`
--
ALTER TABLE `chatgpt_sd_jobs`
ADD PRIMARY KEY (`id`),
ADD UNIQUE KEY `task_id` (`task_id`);
ALTER TABLE `chatgpt_sd_jobs`
MODIFY `id` int NOT NULL AUTO_INCREMENT;