From d51a724ade1213ef0904dee30dbe9176a7446f8f Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 26 Sep 2023 18:16:51 +0800 Subject: [PATCH] feat: add implements for stable diffusion service --- api/core/types/task.go | 70 +++++++++ api/handler/mj_handler.go | 40 ++--- api/handler/sd_handler.go | 315 +++++++++++++++++++++++++++++++++++++ api/service/mj_service.go | 47 +----- api/service/sd_service.go | 95 +++++++++++ api/store/model/sd_job.go | 20 +++ api/store/vo/sd_job.go | 19 +++ database/update-v3.1.4.sql | 25 ++- 8 files changed, 569 insertions(+), 62 deletions(-) create mode 100644 api/core/types/task.go create mode 100644 api/handler/sd_handler.go create mode 100644 api/service/sd_service.go create mode 100644 api/store/model/sd_job.go create mode 100644 api/store/vo/sd_job.go diff --git a/api/core/types/task.go b/api/core/types/task.go new file mode 100644 index 00000000..4ced8ec7 --- /dev/null +++ b/api/core/types/task.go @@ -0,0 +1,70 @@ +package types + +// TaskType 任务类别 +type TaskType string + +func (t TaskType) String() string { + return string(t) +} + +const ( + TaskImage = TaskType("image") + TaskUpscale = TaskType("upscale") + TaskVariation = TaskType("variation") + TaskTxt2Img = TaskType("text2img") +) + +// TaskSrc 任务来源 +type TaskSrc string + +const ( + TaskSrcChat = TaskSrc("chat") // 来自聊天页面 + TaskSrcImg = TaskSrc("img") // 专业绘画页面 +) + +// MjTask MidJourney 任务 +type MjTask struct { + Id int `json:"id"` + SessionId string `json:"session_id"` + Src TaskSrc `json:"src"` + Type TaskType `json:"type"` + UserId int `json:"user_id"` + Prompt string `json:"prompt,omitempty"` + ChatId string `json:"chat_id,omitempty"` + RoleId int `json:"role_id,omitempty"` + Icon string `json:"icon,omitempty"` + Index int32 `json:"index,omitempty"` + MessageId string `json:"message_id,omitempty"` + MessageHash string `json:"message_hash,omitempty"` + RetryCount int `json:"retry_count"` +} + +// SdParams stable diffusion 绘画参数 +type SdParams struct { + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt"` + Steps int `json:"steps"` + Sampler string `json:"sampler"` + FaceFix bool `json:"face_fix"` + CfgScale float32 `json:"cfg_scale"` + Seed int64 `json:"seed"` + Height int `json:"height"` + Width int `json:"width"` + HdFix bool `json:"hd_fix"` + HdRedrawRate float32 `json:"hd_redraw_rate"` + HdScale int `json:"hd_scale"` + HdScaleAlg string `json:"hd_scale_alg"` + HdSampleNum int `json:"hd_sample_num"` +} + +type SdTask struct { + Id int `json:"id"` + SessionId string `json:"session_id"` + Src types.TaskSrc `json:"src"` + Type types.TaskType `json:"type"` + UserId int `json:"user_id"` + Prompt string `json:"prompt,omitempty"` + Params types.SdParams `json:"params"` + RetryCount int `json:"retry_count"` +} diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 22587f98..d0b18fc5 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -66,7 +66,7 @@ func NewMidJourneyHandler( return &h } -type notifyData struct { +type mjNotifyData struct { MessageId string `json:"message_id"` ReferenceId string `json:"reference_id"` Image Image `json:"image"` @@ -98,7 +98,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { resp.NotAuth(c) return } - var data notifyData + var data mjNotifyData if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { resp.ERROR(c, types.InvalidArgs) return @@ -122,14 +122,14 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } -func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) { +func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) { taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() if err != nil { // 过期任务,丢弃 logger.Warn("任务已过期:", err) return nil, true } - var task service.MjTask + var task types.MjTask err = utils.JsonDecode(taskString, &task) if err != nil { // 非标准任务,丢弃 logger.Warn("任务解析失败:", err) @@ -143,7 +143,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro return nil, false } - if task.Src == service.TaskSrcImg { // 绘画任务 + if task.Src == types.TaskSrcImg { // 绘画任务 var job model.MidJourneyJob res := h.db.Where("id = ?", task.Id).First(&job) if res.Error != nil { @@ -191,7 +191,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro } } - } else if task.Src == service.TaskSrcChat { // 聊天任务 + } else if task.Src == types.TaskSrcChat { // 聊天任务 wsClient := h.App.MjTaskClients.Get(task.SessionId) if data.Status == Finished { if wsClient != nil && data.ReferenceId != "" { @@ -342,7 +342,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) userId := utils.IntValue(utils.InterfaceToString(idValue), 0) job := model.MidJourneyJob{ - Type: service.Image.String(), + Type: types.TaskImage.String(), UserId: userId, Progress: 0, Prompt: prompt, @@ -353,11 +353,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } - h.mjService.PushTask(service.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: int(job.Id), SessionId: data.SessionId, - Src: service.TaskSrcImg, - Type: service.Image, + Src: types.TaskSrcImg, + Type: types.TaskImage, Prompt: prompt, UserId: userId, }) @@ -401,10 +401,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) - src := service.TaskSrc(data.Src) - if src == service.TaskSrcImg { + src := types.TaskSrc(data.Src) + if src == types.TaskSrcImg { job := model.MidJourneyJob{ - Type: service.Upscale.String(), + Type: types.TaskUpscale.String(), UserId: userId, Hash: data.MessageHash, Progress: 0, @@ -428,11 +428,11 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { } } } - h.mjService.PushTask(service.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, Src: src, - Type: service.Upscale, + Type: types.TaskUpscale, Prompt: data.Prompt, UserId: userId, RoleId: data.RoleId, @@ -470,10 +470,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { idValue, _ := c.Get(types.LoginUserID) jobId := 0 userId := utils.IntValue(utils.InterfaceToString(idValue), 0) - src := service.TaskSrc(data.Src) - if src == service.TaskSrcImg { + src := types.TaskSrc(data.Src) + if src == types.TaskSrcImg { job := model.MidJourneyJob{ - Type: service.Variation.String(), + Type: types.TaskVariation.String(), UserId: userId, ImgURL: "", Hash: data.MessageHash, @@ -498,11 +498,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { } } } - h.mjService.PushTask(service.MjTask{ + h.mjService.PushTask(types.MjTask{ Id: jobId, SessionId: data.SessionId, Src: src, - Type: service.Variation, + Type: types.TaskVariation, Prompt: data.Prompt, UserId: userId, RoleId: data.RoleId, diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go new file mode 100644 index 00000000..fbeca078 --- /dev/null +++ b/api/handler/sd_handler.go @@ -0,0 +1,315 @@ +package handler + +import ( + "chatplus/core" + "chatplus/core/types" + "chatplus/service" + "chatplus/service/oss" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "chatplus/utils/resp" + "encoding/base64" + "fmt" + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "net/http" + "strings" + "sync" + "time" +) + +type SdJobHandler struct { + BaseHandler + redis *redis.Client + db *gorm.DB + mjService *service.MjService + uploaderManager *oss.UploaderManager + lock sync.Mutex + clients *types.LMap[string, *types.WsClient] +} + +func NewSdJobHandler( + app *core.AppServer, + client *redis.Client, + db *gorm.DB, + manager *oss.UploaderManager, + mjService *service.MjService) *MidJourneyHandler { + h := MidJourneyHandler{ + redis: client, + db: db, + uploaderManager: manager, + lock: sync.Mutex{}, + mjService: mjService, + clients: types.NewLMap[string, *types.WsClient](), + } + h.App = app + return &h +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *SdJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + sessionId := c.Query("session_id") + client := types.NewWsClient(ws) + // 删除旧的连接 + h.clients.Delete(sessionId) + h.clients.Put(sessionId, client) + logger.Infof("New websocket connected, IP: %s", c.ClientIP()) +} + +type sdNotifyData struct { + TaskId string + ImageName string + ImageData string + Progress int + Seed string + Success bool + Message string +} + +func (h *SdJobHandler) Notify(c *gin.Context) { + token := c.GetHeader("Authorization") + if token != h.App.Config.ExtConfig.Token { + resp.NotAuth(c) + return + } + var data sdNotifyData + if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + logger.Debugf("收到 MidJourney 回调请求:%+v", data) + + h.lock.Lock() + defer h.lock.Unlock() + + err, finished := h.notifyHandler(c, data) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + // 解除任务锁定 + if finished && (data.Progress == 100) { + h.redis.Del(c, service.MjRunningJobKey) + } + resp.SUCCESS(c) + +} + +func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) { + taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() + if err != nil { // 过期任务,丢弃 + logger.Warn("任务已过期:", err) + return nil, true + } + + var task types.SdTask + err = utils.JsonDecode(taskString, &task) + if err != nil { // 非标准任务,丢弃 + logger.Warn("任务解析失败:", err) + return nil, false + } + + var job model.SdJob + res := h.db.Where("id = ?", task.Id).First(&job) + if res.Error != nil { + logger.Warn("非法任务:", res.Error) + return nil, false + } + job.Params = utils.JsonEncode(task.Params) + job.ReferenceId = data.ImageData + job.Progress = data.Progress + job.Prompt = data.Prompt + job.Hash = data.Image.Hash + + // 任务完成,将最终的图片下载下来 + if data.Progress == 100 { + imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) + if err != nil { + logger.Error("error with download img: ", err.Error()) + return err, false + } + job.ImgURL = imgURL + } else { + // 临时图片直接保存,访问的时候使用代理进行转发 + job.ImgURL = data.Image.URL + } + res = h.db.Updates(&job) + if res.Error != nil { + logger.Error("error with update job: ", res.Error) + return res.Error, false + } + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + if data.Progress < 100 { + image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) + if err == nil { + jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } + + // 推送任务到前端 + client := h.clients.Get(task.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } + + // 更新用户剩余绘图次数 + if data.Progress == 100 { + h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) + } + + return nil, true +} + +func (h *SdJobHandler) checkLimits(c *gin.Context) bool { + user, err := utils.GetLoginUser(c, h.db) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.ImgCalls <= 0 { + resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *SdJobHandler) Image(c *gin.Context) { + var data struct { + SessionId string `json:"session_id"` + Prompt string `json:"prompt"` + Rate string `json:"rate"` + Model string `json:"model"` + Chaos int `json:"chaos"` + Raw bool `json:"raw"` + Seed int64 `json:"seed"` + Stylize int `json:"stylize"` + Img string `json:"img"` + Weight float32 `json:"weight"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + if !h.checkLimits(c) { + return + } + + var prompt = data.Prompt + if data.Rate != "" && !strings.Contains(prompt, "--ar") { + prompt += " --ar " + data.Rate + } + if data.Seed > 0 && !strings.Contains(prompt, "--seed") { + prompt += fmt.Sprintf(" --seed %d", data.Seed) + } + if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { + prompt += fmt.Sprintf(" --s %d", data.Stylize) + } + if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { + prompt += fmt.Sprintf(" --c %d", data.Chaos) + } + if data.Img != "" { + prompt = fmt.Sprintf("%s %s", data.Img, prompt) + if data.Weight > 0 { + prompt += fmt.Sprintf(" --iw %f", data.Weight) + } + } + if data.Raw { + prompt += " --style raw" + } + if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { + prompt += data.Model + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.MidJourneyJob{ + Type: service.Image.String(), + UserId: userId, + Progress: 0, + Prompt: prompt, + CreatedAt: time.Now(), + } + if res := h.db.Create(&job); res.Error != nil { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + + h.mjService.PushTask(service.MjTask{ + Id: int(job.Id), + SessionId: data.SessionId, + Src: service.TaskSrcImg, + Type: service.Image, + Prompt: prompt, + UserId: userId, + }) + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + // 推送任务到前端 + client := h.clients.Get(data.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } + resp.SUCCESS(c) +} + +// JobList 获取 MJ 任务列表 +func (h *SdJobHandler) JobList(c *gin.Context) { + status := h.GetInt(c, "status", 0) + var items []model.MidJourneyJob + var res *gorm.DB + userId, _ := c.Get(types.LoginUserID) + if status == 1 { + res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) + } else { + res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) + } + if res.Error != nil { + resp.ERROR(c, types.NoData) + return + } + + var jobs = make([]vo.MidJourneyJob, 0) + for _, item := range items { + var job vo.MidJourneyJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + if item.Progress < 100 { + // 30 分钟还没完成的任务直接删除 + if time.Now().Sub(item.CreatedAt) > time.Minute*30 { + h.db.Delete(&item) + continue + } + if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 + image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } + } + jobs = append(jobs, job) + } + resp.SUCCESS(c, jobs) +} diff --git a/api/service/mj_service.go b/api/service/mj_service.go index 16e38827..393e8e91 100644 --- a/api/service/mj_service.go +++ b/api/service/mj_service.go @@ -21,41 +21,6 @@ var logger = logger2.GetLogger() const MjRunningJobKey = "MidJourney_Running_Job" -type TaskType string - -func (t TaskType) String() string { - return string(t) -} - -const ( - Image = TaskType("image") - Upscale = TaskType("upscale") - Variation = TaskType("variation") -) - -type TaskSrc string - -const ( - TaskSrcChat = TaskSrc("chat") - TaskSrcImg = TaskSrc("img") -) - -type MjTask struct { - Id int `json:"id"` - SessionId string `json:"session_id"` - Src TaskSrc `json:"src"` - Type TaskType `json:"type"` - UserId int `json:"user_id"` - Prompt string `json:"prompt,omitempty"` - ChatId string `json:"chat_id,omitempty"` - RoleId int `json:"role_id,omitempty"` - Icon string `json:"icon,omitempty"` - Index int32 `json:"index,omitempty"` - MessageId string `json:"message_id,omitempty"` - MessageHash string `json:"message_hash,omitempty"` - RetryCount int `json:"retry_count"` -} - type MjService struct { config types.ChatPlusExtConfig client *req.Client @@ -78,11 +43,11 @@ func (s *MjService) Run() { ctx := context.Background() for { _, err := s.redis.Get(ctx, MjRunningJobKey).Result() - if err == nil { + if err == nil { // 队列串行执行 time.Sleep(time.Second * 3) continue } - var task MjTask + var task types.MjTask err = s.taskQueue.LPop(&task) if err != nil { logger.Errorf("taking task with error: %v", err) @@ -90,17 +55,17 @@ func (s *MjService) Run() { } logger.Infof("Consuming Task: %+v", task) switch task.Type { - case Image: + case types.TaskImage: err = s.image(task.Prompt) break - case Upscale: + case types.TaskUpscale: err = s.upscale(MjUpscaleReq{ Index: task.Index, MessageId: task.MessageId, MessageHash: task.MessageHash, }) break - case Variation: + case types.TaskVariation: err = s.variation(MjVariationReq{ Index: task.Index, MessageId: task.MessageId, @@ -124,7 +89,7 @@ func (s *MjService) Run() { } } -func (s *MjService) PushTask(task MjTask) { +func (s *MjService) PushTask(task types.MjTask) { logger.Infof("add a new MidJourney Task: %+v", task) s.taskQueue.RPush(task) } diff --git a/api/service/sd_service.go b/api/service/sd_service.go new file mode 100644 index 00000000..a29b42ed --- /dev/null +++ b/api/service/sd_service.go @@ -0,0 +1,95 @@ +package service + +import ( + "chatplus/core/types" + "chatplus/store" + "chatplus/store/model" + "chatplus/utils" + "context" + "errors" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/imroc/req/v3" + "gorm.io/gorm" + "time" +) + +// SD 绘画服务 + +const SdRunningJobKey = "StableDiffusion_Running_Job" + +type SdService struct { + config types.ChatPlusExtConfig + client *req.Client + taskQueue *store.RedisQueue + redis *redis.Client + db *gorm.DB +} + +func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService { + return &SdService{ + config: appConfig.ExtConfig, + redis: client, + db: db, + taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client), + client: req.C().SetTimeout(30 * time.Second)} +} + +func (s *SdService) Run() { + logger.Info("Starting StableDiffusion job consumer.") + ctx := context.Background() + for { + _, err := s.redis.Get(ctx, SdRunningJobKey).Result() + if err == nil { // 队列串行执行 + time.Sleep(time.Second * 3) + continue + } + var task types.SdTask + err = s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + logger.Infof("Consuming Task: %+v", task) + err = s.txt2img(task.Params) + if err != nil { + logger.Error("绘画任务执行失败:", err) + if task.RetryCount <= 5 { + s.taskQueue.RPush(task) + } + task.RetryCount += 1 + time.Sleep(time.Second * 3) + continue + } + + // 更新任务的执行状态 + s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) + // 锁定任务执行通道,直到任务超时(5分钟) + s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5) + } +} + +func (s *SdService) PushTask(task types.SdTask) { + logger.Infof("add a new MidJourney Task: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *SdService) txt2img(params types.SdParams) error { + logger.Infof("SD 绘画参数:%+v", params) + url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("Authorization", s.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + + return nil +} diff --git a/api/store/model/sd_job.go b/api/store/model/sd_job.go new file mode 100644 index 00000000..473321b9 --- /dev/null +++ b/api/store/model/sd_job.go @@ -0,0 +1,20 @@ +package model + +import "time" + +type SdJob struct { + Id uint `gorm:"primarykey;column:id"` + Type string + UserId int + TaskId string + ImgURL string + Progress int + Prompt string + Params string + Started bool + CreatedAt time.Time +} + +func (SdJob) TableName() string { + return "chatgpt_sd_jobs" +} diff --git a/api/store/vo/sd_job.go b/api/store/vo/sd_job.go new file mode 100644 index 00000000..b91cad69 --- /dev/null +++ b/api/store/vo/sd_job.go @@ -0,0 +1,19 @@ +package vo + +import ( + "chatplus/core/types" + "time" +) + +type SdJob struct { + Id uint `json:"id"` + Type string `json:"type"` + UserId int `json:"user_id"` + TaskId string `json:"task_id"` + ImgURL string `json:"img_url"` + Params types.SdParams `json:"params"` + Progress int `json:"progress"` + Prompt string `json:"prompt"` + CreatedAt time.Time `json:"created_at"` + Started bool `json:"started"` +} diff --git a/database/update-v3.1.4.sql b/database/update-v3.1.4.sql index 5ad368f5..24fb8e5b 100644 --- a/database/update-v3.1.4.sql +++ b/database/update-v3.1.4.sql @@ -1,2 +1,25 @@ ALTER TABLE `chatgpt_mj_jobs` ADD `started` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始' AFTER `progress`; -UPDATE `chatgpt_mj_jobs` SET started = 1 \ No newline at end of file +UPDATE `chatgpt_mj_jobs` SET started = 1 + +-- 创建 SD 绘图任务表 +CREATE TABLE `chatgpt_sd_jobs` ( + `id` int NOT NULL, + `user_id` int NOT NULL COMMENT '用户 ID', + `type` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT 'txt2img' COMMENT '任务类别', + `task_id` char(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '任务 ID', + `prompt` varchar(2000) NOT NULL COMMENT '会话提示词', + `img_url` varchar(255) DEFAULT NULL COMMENT '图片URL', + `params` text CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci COMMENT '绘画参数json', + `progress` smallint DEFAULT '0' COMMENT '任务进度', + `started` tinyint(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始', + `created_at` datetime NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='StableDiffusion 任务表'; +-- +-- 表的索引 `chatgpt_sd_jobs` +-- +ALTER TABLE `chatgpt_sd_jobs` + ADD PRIMARY KEY (`id`), + ADD UNIQUE KEY `task_id` (`task_id`); + +ALTER TABLE `chatgpt_sd_jobs` + MODIFY `id` int NOT NULL AUTO_INCREMENT; \ No newline at end of file