package sd

// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import (
	"fmt"
	"geekai/core/types"
	"geekai/service"
	"geekai/service/oss"
	"geekai/store"
	"geekai/store/model"
	"geekai/utils"
	"strings"
	"sync"
	"time"

	"github.com/imroc/req/v3"
	"gorm.io/gorm"
)

// Service is the Stable Diffusion drawing service. It pops SdTask items from
// taskQueue, calls the Stable Diffusion web API and reports progress/results
// through the database (model.SdJob) and notifyQueue.
type Service struct {
	httpClient    *req.Client
	config        types.StableDiffusionConfig
	taskQueue     *store.RedisQueue
	notifyQueue   *store.RedisQueue
	db            *gorm.DB
	uploadManager *oss.UploaderManager
	name          string // service name
	leveldb       *store.LevelDB
	running       bool // whether Run keeps consuming tasks; read/written under mu in this file
	mu            sync.Mutex
}

// NewService creates a Stable Diffusion service called name that talks to the
// API at config.ApiURL (all trailing "/" characters are stripped). The returned
// service is marked as running; call Run to start consuming tasks.
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
	config.ApiURL = strings.TrimRight(config.ApiURL, "/")
	return &Service{
		name:          name,
		config:        config,
		httpClient:    req.C(),
		taskQueue:     taskQueue,
		notifyQueue:   notifyQueue,
		db:            db,
		leveldb:       levelDB,
		uploadManager: manager,
		running:       true,
	}
}

// Run consumes Stable Diffusion tasks from the task queue until Stop is called.
//
// Prompts containing Chinese are rewritten (prompt) or translated (negative
// prompt) through an OpenAI request before drawing; if that request fails the
// original text is kept and a warning is logged. When drawing fails, the job
// record gets progress -1 plus the error message and a Failed notification is
// pushed to the notify queue.
//
// NOTE(review): Stop is only checked between tasks; if LPop blocks while the
// queue is empty, Run returns only after the next task arrives or LPop fails.
func (s *Service) Run() {
	logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
	for s.isRunning() {
		var task types.SdTask
		err := s.taskQueue.LPop(&task)
		if err != nil {
			logger.Errorf("taking task with error: %v", err)
			// Back off so a persistent queue error does not spin the CPU and flood the log.
			time.Sleep(time.Second)
			continue
		}

		// translate prompt
		if utils.HasChinese(task.Params.Prompt) {
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
			if err == nil {
				task.Params.Prompt = content
			} else {
				logger.Warnf("error with translate prompt: %v", err)
			}
		}

		// translate negative prompt
		if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
			if err == nil {
				task.Params.NegPrompt = content
			} else {
				logger.Warnf("error with translate prompt: %v", err)
			}
		}

		logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task)
		err = s.Txt2Img(task)
		if err != nil {
			logger.Error("绘画任务执行失败:", err.Error())
			// mark the job as failed and record the reason
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
				"progress": -1,
				"err_msg":  err.Error(),
			})
			// notify the frontend that the job failed
			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
			continue
		}
	}
}

// isRunning reports whether Run should keep consuming tasks.
func (s *Service) isRunning() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.running
}

// Stop asks Run to exit before it takes the next task. It is safe to call from
// a goroutine other than the one executing Run.
func (s *Service) Stop() {
	s.mu.Lock()
	s.running = false
	s.mu.Unlock()
}

// Txt2ImgReq is the request body sent to the /sdapi/v1/txt2img endpoint.
type Txt2ImgReq struct {
	Prompt            string  `json:"prompt"`
	NegativePrompt    string  `json:"negative_prompt"`
	Seed              int64   `json:"seed,omitempty"`
	Steps             int     `json:"steps"`
	CfgScale          float32 `json:"cfg_scale"`
	Width             int     `json:"width"`
	Height            int     `json:"height"`
	SamplerName       string  `json:"sampler_name"`
	Scheduler         string  `json:"scheduler"`
	EnableHr          bool    `json:"enable_hr,omitempty"`
	HrScale           int     `json:"hr_scale,omitempty"`
	HrUpscaler        string  `json:"hr_upscaler,omitempty"`
	HrSecondPassSteps int     `json:"hr_second_pass_steps,omitempty"`
	DenoisingStrength float32 `json:"denoising_strength,omitempty"`
	ForceTaskId       string  `json:"force_task_id,omitempty"`
}

// Txt2ImgResp is the response body returned by the /sdapi/v1/txt2img endpoint.
// Images holds Base64-encoded images; Info is a JSON string that Txt2Img
// decodes to read the "seed" actually used.
type Txt2ImgResp struct {
	Images     []string `json:"images"`
	Parameters struct {
	} `json:"parameters"`
	Info string `json:"info"`
}

// TaskProgressResp is the response body returned by the /sdapi/v1/progress endpoint.
type TaskProgressResp struct {
	Progress     float64 `json:"progress"`
	EtaRelative  float64 `json:"eta_relative"`
	CurrentImage string  `json:"current_image"`
}

// Txt2Img runs a text-to-image job for task against the Stable Diffusion API.
//
// While the txt2img request is in flight it polls the progress endpoint about
// once per second (see reportProgress). On success it sets the job's progress
// to 100, pushes a Finished notification and deletes the cached preview image
// stored in leveldb under task.Params.TaskId. Any error is returned to the
// caller; this function does not mark the job as failed itself.
//
// NOTE(review): on failure the cached preview is left in leveldb — confirm
// whether something else cleans it up.
func (s *Service) Txt2Img(task types.SdTask) error {
	body := Txt2ImgReq{
		Prompt:         task.Params.Prompt,
		NegativePrompt: task.Params.NegPrompt,
		Steps:          task.Params.Steps,
		CfgScale:       task.Params.CfgScale,
		Width:          task.Params.Width,
		Height:         task.Params.Height,
		SamplerName:    task.Params.Sampler,
		Scheduler:      task.Params.Scheduler,
		ForceTaskId:    task.Params.TaskId,
	}
	// Only send an explicit seed when one was chosen; otherwise the field is omitted.
	if task.Params.Seed > 0 {
		body.Seed = task.Params.Seed
	}
	if task.Params.HdFix {
		body.EnableHr = true
		body.HrScale = task.Params.HdScale
		body.HrUpscaler = task.Params.HdScaleAlg
		body.HrSecondPassSteps = task.Params.HdSteps
		body.DenoisingStrength = task.Params.HdRedrawRate
	}

	apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
	logger.Debugf("send image request to %s", apiURL)

	// Buffered so the request goroutine can always deliver its result and exit.
	errChan := make(chan error, 1)
	go func() {
		errChan <- s.requestTxt2Img(apiURL, body, task)
	}()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	// Wait for the request to finish, reporting progress on every tick.
	for {
		select {
		case err := <-errChan:
			if err != nil {
				return err
			}
			// task finished
			s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
			s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
			// The preview is no longer needed once the final image is stored.
			_ = s.leveldb.Delete(task.Params.TaskId)
			return nil
		case <-ticker.C:
			s.reportProgress(task)
		}
	}
}

// requestTxt2Img posts body to apiURL, uploads the first returned image and
// saves the image URL, the final params (with the seed the server reported)
// and the prompt to the job record identified by task.Id.
func (s *Service) requestTxt2Img(apiURL string, body Txt2ImgReq, task types.SdTask) error {
	var res Txt2ImgResp
	response, err := s.httpClient.R().
		SetHeader("Authorization", s.config.ApiKey).
		SetBody(body).
		SetSuccessResult(&res).
		Post(apiURL)
	if err != nil {
		return err
	}
	if response.IsErrorState() {
		return fmt.Errorf("error http code status: %v", response.Status)
	}
	// Indexing Images[0] on an empty result would panic and, because this runs
	// in its own goroutine, take down the whole process.
	if len(res.Images) == 0 {
		return fmt.Errorf("no image returned from %s", apiURL)
	}

	// Upload the Base64-encoded image.
	imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
	if err != nil {
		return fmt.Errorf("error with upload image: %w", err)
	}

	// Read the seed the server actually used (the request may have omitted it).
	var info map[string]interface{}
	if err := utils.JsonDecode(res.Info, &info); err != nil {
		return fmt.Errorf("error with decode task response: %w", err)
	}
	task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1))

	// Without this record the job would be reported finished with no image.
	err = s.db.Model(&model.SdJob{Id: uint(task.Id)}).
		UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params), Prompt: task.Params.Prompt}).
		Error
	if err != nil {
		return fmt.Errorf("error with save task result: %w", err)
	}
	return nil
}

// reportProgress polls the progress endpoint once. When it reports progress
// greater than 0, it updates the job's progress (as a percentage), pushes a
// Running notification and caches a non-empty preview image in leveldb under
// task.Params.TaskId. Polling errors are ignored; progress reporting is best-effort.
func (s *Service) reportProgress(task types.SdTask) {
	err, resp := s.checkTaskProgress()
	if err != nil || resp.Progress <= 0 {
		return
	}
	s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
	// send a status-update signal
	s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
	// cache the preview image
	if resp.CurrentImage != "" {
		_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
	}
}

// checkTaskProgress queries the /sdapi/v1/progress endpoint with
// skip_current_image=false, so the preview image is requested as well.
// It returns a non-nil error and a nil response when the request fails or the
// server answers with an HTTP error status.
//
// NOTE(review): the (error, result) return order is unidiomatic Go; it is kept
// so that any other callers in this package keep compiling.
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
	apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
	var res TaskProgressResp
	response, err := s.httpClient.R().
		SetHeader("Authorization", s.config.ApiKey).
		SetSuccessResult(&res).
		Get(apiURL)
	if err != nil {
		return err, nil
	}
	if response.IsErrorState() {
		return fmt.Errorf("error http code status: %v", response.Status), nil
	}

	return nil, &res
}