mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-11 11:43:43 +08:00
upgrade to v4.0.4
This commit is contained in:
@@ -60,16 +60,16 @@ func (p *ServicePool) CheckTaskNotify() {
|
||||
go func() {
|
||||
logger.Info("Running Stable-Diffusion task notify checking ...")
|
||||
for {
|
||||
var userId uint
|
||||
err := p.notifyQueue.LPop(&userId)
|
||||
var message NotifyMessage
|
||||
err := p.notifyQueue.LPop(&message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
client := p.Clients.Get(userId)
|
||||
client := p.Clients.Get(uint(message.UserId))
|
||||
if client == nil {
|
||||
continue
|
||||
}
|
||||
err = client.Send([]byte("Task Updated"))
|
||||
err = client.Send([]byte(message.Message))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (p *ServicePool) CheckTaskStatus() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
@@ -80,7 +81,7 @@ func (s *Service) Run() {
|
||||
"err_msg": err.Error(),
|
||||
})
|
||||
// 通知前端,任务失败
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -145,8 +146,13 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
var errChan = make(chan error)
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
// send a request to sd api endpoint
|
||||
go func() {
|
||||
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
@@ -174,14 +180,17 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
errChan <- nil
|
||||
}()
|
||||
|
||||
// waiting for task finish
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan: // 任务完成
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// task finished
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
|
||||
// 从 leveldb 中删除预览图片数据
|
||||
_ = s.leveldb.Delete(task.Params.TaskId)
|
||||
return nil
|
||||
@@ -191,7 +200,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
if err == nil && resp.Progress > 0 {
|
||||
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
|
||||
// 发送更新状态信号
|
||||
s.notifyQueue.RPush(task.UserId)
|
||||
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
|
||||
// 保存预览图片数据
|
||||
if resp.CurrentImage != "" {
|
||||
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)
|
||||
@@ -207,7 +216,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
@@ -4,44 +4,14 @@ import logger2 "chatplus/logger"
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type TaskInfo struct {
|
||||
UserId uint `json:"user_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
JobId int `json:"job_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Data []interface{} `json:"data"`
|
||||
EventData interface{} `json:"event_data"`
|
||||
FnIndex int `json:"fn_index"`
|
||||
SessionHash string `json:"session_hash"`
|
||||
type NotifyMessage struct {
|
||||
UserId int `json:"user_id"`
|
||||
JobId int `json:"job_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
UserId uint
|
||||
SessionId string
|
||||
JobId int
|
||||
TaskId string
|
||||
ImageName string
|
||||
ImageData string
|
||||
Progress int
|
||||
Seed int64
|
||||
Success bool
|
||||
Message string
|
||||
}
|
||||
|
||||
var ParamKeys = map[string]int{
|
||||
"task_id": 0,
|
||||
"prompt": 1,
|
||||
"negative_prompt": 2,
|
||||
"steps": 4,
|
||||
"sampler": 5,
|
||||
"face_fix": 7, // 面部修复
|
||||
"cfg_scale": 8,
|
||||
"seed": 27,
|
||||
"height": 10,
|
||||
"width": 9,
|
||||
"hd_fix": 11,
|
||||
"hd_redraw_rate": 12, //高清修复重绘幅度
|
||||
"hd_scale": 13, // 高清修复放大倍数
|
||||
"hd_scale_alg": 14, // 高清修复放大算法
|
||||
"hd_sample_num": 15, // 高清修复采样次数
|
||||
}
|
||||
const (
|
||||
Running = "RUNNING"
|
||||
Finished = "FINISH"
|
||||
Failed = "FAIL"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user