diff --git a/api/core/types/config.go b/api/core/types/config.go index a3ae1b04..39990fb5 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -56,10 +56,10 @@ type MidJourneyConfig struct { } type StableDiffusionConfig struct { - Enabled bool - ApiURL string - ApiKey string - Txt2ImgJsonPath string + Enabled bool + Model string // 模型名称 + ApiURL string + ApiKey string } type MidJourneyPlusConfig struct { diff --git a/api/main.go b/api/main.go index d1235e2b..233f38e9 100644 --- a/api/main.go +++ b/api/main.go @@ -175,6 +175,12 @@ func main() { // Stable Diffusion 机器人 fx.Provide(sd.NewServicePool), + fx.Invoke(func(pool *sd.ServicePool) { + if pool.HasAvailableService() { + pool.CheckTaskNotify() + pool.CheckTaskStatus() + } + }), fx.Provide(payment.NewAlipayService), fx.Provide(payment.NewHuPiPay), diff --git a/api/res/sd/text2img.json b/api/res/sd/text2img.json deleted file mode 100644 index c15cc35e..00000000 --- a/api/res/sd/text2img.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "data": [ - "task(cxvkpawy8onnfti)", - "a cute girl", - "", - [], - 20, - "DPM++ 2M Karras", - 1, - 1, - 7, - 512, - 512, - false, - 0.7, - 2, - "Latent", - 0, - 0, - 0, - "Use same checkpoint", - "Use same sampler", - "", - "", - [], - "None", - false, - "", - 0.8, - -1, - false, - -1, - 0, - 0, - 0, - null, - null, - null, - null, - false, - false, - "positive", - "comma", - 0, - false, - false, - "", - "Seed", - "", - [], - "Nothing", - "", - [], - "Nothing", - "", - [], - true, - false, - false, - false, - 0, - null, - null, - false, - null, - null, - false, - null, - null, - false, - 50, - [], - "", - "", - "" - ], - "event_data": null, - "fn_index": 446, - "session_hash": "nk5noh1rz1o" -} \ No newline at end of file diff --git a/api/service/oss/aliyun_oss.go b/api/service/oss/aliyun_oss.go index 28ec6882..d713d278 100644 --- a/api/service/oss/aliyun_oss.go +++ b/api/service/oss/aliyun_oss.go @@ -4,6 +4,7 @@ import ( "bytes" "chatplus/core/types" "chatplus/utils" + "encoding/base64" "fmt" "net/url" "path/filepath" @@ -101,6 +102,20 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil } +func (s AliYunOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + // 上传文件字节数据 + err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil +} + func (s AliYunOss) Delete(fileURL string) error { var objectKey string if strings.HasPrefix(fileURL, "http") { diff --git a/api/service/oss/localstorage.go b/api/service/oss/localstorage.go index 184b6011..aeff4427 100644 --- a/api/service/oss/localstorage.go +++ b/api/service/oss/localstorage.go @@ -3,13 +3,13 @@ package oss import ( "chatplus/core/types" "chatplus/utils" + "encoding/base64" "fmt" + "github.com/gin-gonic/gin" "net/url" "os" "path/filepath" "strings" - - "github.com/gin-gonic/gin" ) type LocalStorage struct { @@ -73,6 +73,20 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil } +func (s LocalStorage) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + filePath, err := utils.GenUploadPath(s.config.BasePath, "", true) + err = os.WriteFile(filePath, imageData, 0644) + if err != nil { + return "", fmt.Errorf("error writing to file:%v", err) + } + + return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil +} + func (s LocalStorage) Delete(fileURL string) error { if _, err := os.Stat(fileURL); err == nil { return os.Remove(fileURL) diff --git a/api/service/oss/minio_oss.go b/api/service/oss/minio_oss.go index ba11d333..75a9cfbb 100644 --- a/api/service/oss/minio_oss.go +++ b/api/service/oss/minio_oss.go @@ -4,6 +4,7 @@ import ( "chatplus/core/types" "chatplus/utils" "context" + "encoding/base64" "fmt" "net/url" "path/filepath" @@ -96,6 +97,25 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) { }, nil } +func (s MiniOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + info, err := s.client.PutObject( + context.Background(), + s.config.Bucket, + objectKey, + strings.NewReader(string(imageData)), + int64(len(imageData)), + minio.PutObjectOptions{ContentType: "image/png"}) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil +} + func (s MiniOss) Delete(fileURL string) error { var objectKey string if strings.HasPrefix(fileURL, "http") { diff --git a/api/service/oss/qiniu_oss.go b/api/service/oss/qiniu_oss.go index 84aa941c..79ea5d1c 100644 --- a/api/service/oss/qiniu_oss.go +++ b/api/service/oss/qiniu_oss.go @@ -5,6 +5,7 @@ import ( "chatplus/core/types" "chatplus/utils" "context" + "encoding/base64" "fmt" "net/url" "path/filepath" @@ -112,6 +113,22 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil } +func (s QinNiuOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + ret := storage.PutRet{} + extra := storage.PutExtra{} + // 上传文件字节数据 + err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil +} + func (s QinNiuOss) Delete(fileURL string) error { var objectKey string if strings.HasPrefix(fileURL, "http") { diff --git a/api/service/oss/uploader.go b/api/service/oss/uploader.go index ce410d02..be6a6f53 100644 --- a/api/service/oss/uploader.go +++ b/api/service/oss/uploader.go @@ -17,5 +17,6 @@ type File struct { type Uploader interface { PutFile(ctx *gin.Context, name string) (File, error) PutImg(imageURL string, useProxy bool) (string, error) + PutBase64(imageData string) (string, error) Delete(fileURL string) error } diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index e52baed9..2661edd1 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -25,14 +25,14 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) // create mj client and service - for k, config := range appConfig.SdConfigs { + for _, config := range appConfig.SdConfigs { if config.Enabled == false { continue } // create sd service - name := fmt.Sprintf("StableDifffusion Service-%d", k) - service := NewService(name, 1, 300, config, taskQueue, notifyQueue, db, manager) + name := fmt.Sprintf("StableDifffusion Service-%s", config.Model) + service := NewService(name, config, taskQueue, notifyQueue, db, manager) // run sd service go func() { service.Run() @@ -58,6 +58,7 @@ func (p *ServicePool) PushTask(task types.SdTask) { func (p *ServicePool) CheckTaskNotify() { go func() { + logger.Info("Running Stable-Diffusion task notify checking ...") for { var userId uint err := p.notifyQueue.LPop(&userId) @@ -79,6 +80,7 @@ func (p *ServicePool) CheckTaskNotify() { // CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 func (p *ServicePool) CheckTaskStatus() { go func() { + logger.Info("Running Stable-Diffusion task status checking ...") for { var jobs []model.SdJob res := p.db.Where("progress < ?", 100).Find(&jobs) diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 074c34c5..6cae7b0a 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -6,59 +6,40 @@ import ( "chatplus/store" "chatplus/store/model" "chatplus/utils" - "encoding/json" "fmt" - "io" - "os" - "strconv" - "sync/atomic" - "time" - "github.com/imroc/req/v3" "gorm.io/gorm" + "strings" + "time" ) // SD 绘画服务 type Service struct { - httpClient *req.Client - config types.StableDiffusionConfig - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - uploadManager *oss.UploaderManager - name string // service name - maxHandleTaskNum int32 // max task number current service can handle - handledTaskNum int32 // already handled task number - taskStartTimes map[int]time.Time // task start time, to check if the task is timeout - taskTimeout int64 + httpClient *req.Client + config types.StableDiffusionConfig + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + uploadManager *oss.UploaderManager + name string // service name } -func NewService(name string, maxTaskNum int32, timeout int64, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { +func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager) *Service { + config.ApiURL = strings.TrimRight(config.ApiURL, "/") return &Service{ - name: name, - config: config, - httpClient: req.C(), - taskQueue: taskQueue, - notifyQueue: notifyQueue, - db: db, - uploadManager: manager, - taskTimeout: timeout, - maxHandleTaskNum: maxTaskNum, - taskStartTimes: make(map[int]time.Time), + name: name, + config: config, + httpClient: req.C(), + taskQueue: taskQueue, + notifyQueue: notifyQueue, + db: db, + uploadManager: manager, } } func (s *Service) Run() { for { - s.checkTasks() - if !s.canHandleTask() { - // current service is full, can not handle more task - // waiting for running task finish - time.Sleep(time.Second * 3) - continue - } - var task types.SdTask err := s.taskQueue.LPop(&task) if err != nil { @@ -74,239 +55,135 @@ func (s *Service) Run() { "progress": -1, "err_msg": err.Error(), }) - // release task num - atomic.AddInt32(&s.handledTaskNum, -1) // 通知前端,任务失败 s.notifyQueue.RPush(task.UserId) continue } - - // lock the task until the execute timeout - s.taskStartTimes[task.Id] = time.Now() - atomic.AddInt32(&s.handledTaskNum, 1) } } -// check if current service instance can handle more task -func (s *Service) canHandleTask() bool { - handledNum := atomic.LoadInt32(&s.handledTaskNum) - return handledNum < s.maxHandleTaskNum +// Txt2ImgReq 文生图请求实体 +type Txt2ImgReq struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt"` + Seed int64 `json:"seed"` + Steps int `json:"steps"` + CfgScale float32 `json:"cfg_scale"` + Width int `json:"width"` + Height int `json:"height"` + SamplerName string `json:"sampler_name"` + EnableHr bool `json:"enable_hr,omitempty"` + HrScale int `json:"hr_scale,omitempty"` + HrUpscaler string `json:"hr_upscaler,omitempty"` + HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"` + DenoisingStrength float32 `json:"denoising_strength,omitempty"` + ForceTaskId string `json:"force_task_id,omitempty"` } -// remove the expired tasks -func (s *Service) checkTasks() { - for k, t := range s.taskStartTimes { - if time.Now().Unix()-t.Unix() > s.taskTimeout { - delete(s.taskStartTimes, k) - atomic.AddInt32(&s.handledTaskNum, -1) - // delete task from database - s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") - } - } +// Txt2ImgResp 文生图响应实体 +type Txt2ImgResp struct { + Images []string `json:"images"` + Parameters struct { + } `json:"parameters"` + Info string `json:"info"` +} + +// TaskProgressResp 任务进度响应实体 +type TaskProgressResp struct { + Progress float64 `json:"progress"` + EtaRelative float64 `json:"eta_relative"` + CurrentImage string `json:"current_image"` } // Txt2Img 文生图 API func (s *Service) Txt2Img(task types.SdTask) error { - var taskInfo TaskInfo - bytes, err := os.ReadFile(s.config.Txt2ImgJsonPath) - if err != nil { - return fmt.Errorf("error with load text2img json template file: %s", err.Error()) + body := Txt2ImgReq{ + Prompt: task.Params.Prompt, + NegativePrompt: task.Params.NegativePrompt, + Steps: task.Params.Steps, + CfgScale: task.Params.CfgScale, + Width: task.Params.Width, + Height: task.Params.Height, + SamplerName: task.Params.Sampler, } - - err = json.Unmarshal(bytes, &taskInfo) - if err != nil { - return fmt.Errorf("error with decode json params: %s", err.Error()) + if task.Params.Seed > 0 { + body.Seed = task.Params.Seed } - - data := taskInfo.Data - params := task.Params - data[ParamKeys["task_id"]] = params.TaskId - data[ParamKeys["prompt"]] = params.Prompt - data[ParamKeys["negative_prompt"]] = params.NegativePrompt - data[ParamKeys["steps"]] = params.Steps - data[ParamKeys["sampler"]] = params.Sampler - // @fix bug: 有些 stable diffusion 没有面部修复功能 - //data[ParamKeys["face_fix"]] = params.FaceFix - data[ParamKeys["cfg_scale"]] = params.CfgScale - data[ParamKeys["seed"]] = params.Seed - data[ParamKeys["height"]] = params.Height - data[ParamKeys["width"]] = params.Width - data[ParamKeys["hd_fix"]] = params.HdFix - data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate - data[ParamKeys["hd_scale"]] = params.HdScale - data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg - data[ParamKeys["hd_sample_num"]] = params.HdSteps - - taskInfo.SessionId = task.SessionId - taskInfo.TaskId = params.TaskId - taskInfo.Data = data - taskInfo.JobId = task.Id - taskInfo.UserId = uint(task.UserId) + if task.Params.HdFix { + body.EnableHr = true + body.HrScale = task.Params.HdScale + body.HrUpscaler = task.Params.HdScaleAlg + body.HrSecondPassSteps = task.Params.HdSteps + body.DenoisingStrength = task.Params.HdRedrawRate + } + var res Txt2ImgResp + var errChan = make(chan error) + apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) + logger.Debugf("send image request to %s", apiURL) go func() { - s.runTask(taskInfo, s.httpClient) - }() - return nil -} - -// 执行任务 -func (s *Service) runTask(taskInfo TaskInfo, client *req.Client) { - body := map[string]any{ - "data": taskInfo.Data, - "event_data": taskInfo.EventData, - "fn_index": taskInfo.FnIndex, - "session_hash": taskInfo.SessionHash, - } - var result = make(chan CBReq) - go func() { - var res struct { - Data []interface{} `json:"data"` - IsGenerating bool `json:"is_generating"` - Duration float64 `json:"duration"` - AverageDuration float64 `json:"average_duration"` - } - var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} - response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(s.config.ApiURL + "/run/predict") + response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL) if err != nil { - cbReq.Message = "error with send request: " + err.Error() - cbReq.Success = false - result <- cbReq + errChan <- err return } - if response.IsErrorState() { - bytes, _ := io.ReadAll(response.Body) - cbReq.Message = "error http status code: " + string(bytes) - cbReq.Success = false - result <- cbReq + errChan <- fmt.Errorf("error http code status: %v", response.Status) return } - var images []struct { - Name string `json:"name"` - Data interface{} `json:"data"` - IsFile bool `json:"is_file"` - } - err = utils.ForceCovert(res.Data[0], &images) + // 保存 Base64 图片 + imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) if err != nil { - cbReq.Message = "error with decode image:" + err.Error() - cbReq.Success = false - result <- cbReq + errChan <- fmt.Errorf("error with upload image: %v", err) return } - - var info map[string]any - err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) + // 获取绘画真实的 seed + var info map[string]interface{} + err = utils.JsonDecode(res.Info, &info) if err != nil { - logger.Error(res.Data) - cbReq.Message = "error with decode image url:" + err.Error() - cbReq.Success = false - result <- cbReq + errChan <- fmt.Errorf("error with decode task response: %v", err) return } - - // 获取真实的 seed 值 - cbReq.ImageName = images[0].Name - seed, _ := strconv.ParseInt(utils.InterfaceToString(info["seed"]), 10, 64) - cbReq.Seed = seed - cbReq.Success = true - cbReq.Progress = 100 - result <- cbReq - close(result) - + task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1)) + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)}) + errChan <- nil }() for { select { - case value := <-result: - s.callback(value) - return + case err := <-errChan: // 任务完成 + if err != nil { + return err + } + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) + s.notifyQueue.RPush(task.UserId) + return nil default: - var progressReq = map[string]any{ - "id_task": taskInfo.TaskId, - "id_live_preview": 1, + err, resp := s.checkTaskProgress() + // 更新任务进度 + if err == nil && resp.Progress > 0 { + logger.Debugf("Check task progress: %+v", resp.Progress) + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) + // 发送更新状态信号 + s.notifyQueue.RPush(task.UserId) } - - var progressRes struct { - Active bool `json:"active"` - Queued bool `json:"queued"` - Completed bool `json:"completed"` - Progress float64 `json:"progress"` - Eta float64 `json:"eta"` - LivePreview string `json:"live_preview"` - IDLivePreview int `json:"id_live_preview"` - TextInfo interface{} `json:"textinfo"` - } - response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(s.config.ApiURL + "/internal/progress") - var cbReq = CBReq{UserId: taskInfo.UserId, TaskId: taskInfo.TaskId, Success: true, JobId: taskInfo.JobId, SessionId: taskInfo.SessionId} - if err != nil { // TODO: 这里可以考虑设置失败重试次数 - logger.Error(err) - return - } - - if response.IsErrorState() { - bytes, _ := io.ReadAll(response.Body) - logger.Error(string(bytes)) - return - } - - cbReq.ImageData = progressRes.LivePreview - cbReq.Progress = int(progressRes.Progress * 100) - s.callback(cbReq) time.Sleep(time.Second) } } + } -func (s *Service) callback(data CBReq) { - // release task num - atomic.AddInt32(&s.handledTaskNum, -1) - if data.Success { // 任务成功 - var job model.SdJob - res := s.db.Where("id = ?", data.JobId).First(&job) - if res.Error != nil { - logger.Warn("非法任务:", res.Error) - return - } - // 更新任务进度 - job.Progress = data.Progress - // 更新任务 seed - var params types.SdTaskParams - err := utils.JsonDecode(job.Params, ¶ms) - if err != nil { - logger.Error("任务解析失败:", err) - return - } - - params.Seed = data.Seed - if data.ImageName != "" { // 下载图片 - job.ImgURL = fmt.Sprintf("%s/file=%s", s.config.ApiURL, data.ImageName) - if data.Progress == 100 { - imageURL, err := s.uploadManager.GetUploadHandler().PutImg(job.ImgURL, false) - if err != nil { - logger.Error("error with download img: ", err.Error()) - return - } - job.ImgURL = imageURL - } - } - - job.Params = utils.JsonEncode(params) - res = s.db.Updates(&job) - if res.Error != nil { - logger.Error("error with update job: ", res.Error) - return - } - - logger.Debugf("绘图进度:%d", data.Progress) - } else { // 任务失败 - logger.Error("任务执行失败:", data.Message) - // update the task progress - s.db.Model(&model.SdJob{Id: uint(data.JobId)}).UpdateColumns(map[string]interface{}{ - "progress": -1, - "err_msg": data.Message, - }) +// 执行任务 +func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { + apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) + var res TaskProgressResp + response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL) + if err != nil { + return err, nil + } + if response.IsErrorState() { + return fmt.Errorf("error http code status: %v", response.Status), nil } - // 发送更新状态信号 - s.notifyQueue.RPush(data.UserId) + return nil, &res } diff --git a/api/test/test.go b/api/test/test.go index 1ca737cf..79058077 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,21 +1,5 @@ package main -import ( - "chatplus/utils" - "fmt" -) - -type Person struct { - Name string - Age int -} - -type Student struct { - Person - School string -} - func main() { - fmt.Println(utils.RandString(64)) }