geekai/api/service/sd/client.go

170 lines
4.4 KiB
Go

package sd
import (
"chatplus/core/types"
"chatplus/utils"
"fmt"
"github.com/imroc/req/v3"
"io"
"time"
)
// Client bundles the HTTP client and the Stable Diffusion configuration that
// the methods of this type use to talk to the API at config.ApiURL.
type Client struct {
// httpClient is the req client handed to runTask for every outgoing request.
httpClient *req.Client
// config points at the SdConfig section of the application configuration.
config *types.StableDiffusionConfig
}
// NewSdClient returns a Client wired to the SdConfig section of the given
// application config and to a client obtained from req.C().
//
// config must be non-nil: its SdConfig field is addressed directly.
func NewSdClient(config *types.AppConfig) *Client {
	client := new(Client)
	client.config = &config.SdConfig
	client.httpClient = req.C()
	return client
}
// Txt2Img builds a text-to-image request from params and starts it in the
// background.
//
// The request payload is produced by decoding Text2ImgParamTemplate into a
// slice and overwriting the positions named in ParamKeys with the values
// carried by params. The task is then handed to runTask on a new goroutine,
// so this method returns without waiting for the image.
//
// The only error returned is a failure to decode the template; failures of
// the request itself are handled inside runTask.
func (c *Client) Txt2Img(params types.SdTaskParams) error {
	var payload []any
	if decodeErr := utils.JsonDecode(Text2ImgParamTemplate, &payload); decodeErr != nil {
		return decodeErr
	}

	// Identification and prompt text.
	payload[ParamKeys["task_id"]] = params.TaskId
	payload[ParamKeys["prompt"]] = params.Prompt
	payload[ParamKeys["negative_prompt"]] = params.NegativePrompt

	// Sampling settings.
	payload[ParamKeys["steps"]] = params.Steps
	payload[ParamKeys["sampler"]] = params.Sampler
	payload[ParamKeys["face_fix"]] = params.FaceFix
	payload[ParamKeys["cfg_scale"]] = params.CfgScale
	payload[ParamKeys["seed"]] = params.Seed

	// Output dimensions.
	payload[ParamKeys["height"]] = params.Height
	payload[ParamKeys["width"]] = params.Width

	// "hd_*" settings.
	payload[ParamKeys["hd_fix"]] = params.HdFix
	payload[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
	payload[ParamKeys["hd_scale"]] = params.HdScale
	payload[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
	payload[ParamKeys["hd_sample_num"]] = params.HdSampleNum

	// NOTE(review): FnIndex and SessionHash are hard-coded; presumably they
	// match one specific server build — confirm before reusing elsewhere.
	job := TaskInfo{
		TaskId:      params.TaskId,
		Data:        payload,
		EventData:   nil,
		FnIndex:     494,
		SessionHash: "ycaxgzm9ah",
	}

	go func() {
		c.runTask(job, c.httpClient)
	}()
	return nil
}
// runTask submits taskInfo to the API at c.config.ApiURL and blocks until the
// task finishes or progress polling fails.
//
// The POST to /run/predict is issued from a separate goroutine (see
// requestPredict). While that request is in flight, this function polls
// /internal/progress about once per second and prints the reported progress
// and live preview to stdout. When the predict result arrives it is logged:
// the image location on success, the failure message otherwise.
//
// Parameters:
//   - taskInfo: the task to run; its TaskId is also used as the polling id.
//   - client: the req client used for both the predict and progress requests.
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
	// Buffered with capacity 1 so the predict goroutine can always deliver
	// its single result and exit, even if the polling loop below has already
	// returned because a progress request failed. With an unbuffered channel
	// that goroutine would block on the send forever.
	result := make(chan CBReq, 1)
	go func() {
		result <- c.requestPredict(taskInfo, client)
	}()

	for {
		select {
		case value := <-result:
			if value.Success {
				logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
			} else {
				// Surface the failure instead of dropping it silently.
				logger.Error(value.Message)
			}
			return
		default:
			progressReq := map[string]any{
				"id_task":         taskInfo.TaskId,
				"id_live_preview": 1,
			}
			var progressRes struct {
				Active        bool    `json:"active"`
				Queued        bool    `json:"queued"`
				Completed     bool    `json:"completed"`
				Progress      float64 `json:"progress"`
				Eta           float64 `json:"eta"`
				LivePreview   string  `json:"live_preview"`
				IDLivePreview int     `json:"id_live_preview"`
				TextInfo      any     `json:"textinfo"`
			}
			response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
			if err != nil { // TODO: consider retrying a limited number of times here
				logger.Error(err)
				return
			}
			if response.IsErrorState() {
				// Best effort: the body is only used as diagnostic text, so a
				// read error simply results in an empty log message.
				payload, _ := io.ReadAll(response.Body)
				logger.Error(string(payload))
				return
			}

			fmt.Println("Progress: ", progressRes.Progress)
			fmt.Println("Image: ", progressRes.LivePreview)
			time.Sleep(time.Second)
		}
	}
}

// requestPredict performs the blocking POST to /run/predict and turns the
// reply into a CBReq for taskInfo. Every failure is reported through
// Success=false plus a Message; on success ImageName, Seed and Progress=100
// are filled in. The reply is length-checked before it is indexed so that a
// malformed response yields an error result rather than an index panic.
func (c *Client) requestPredict(taskInfo TaskInfo, client *req.Client) CBReq {
	cbReq := CBReq{TaskId: taskInfo.TaskId}

	body := map[string]any{
		"data":         taskInfo.Data,
		"event_data":   taskInfo.EventData,
		"fn_index":     taskInfo.FnIndex,
		"session_hash": taskInfo.SessionHash,
	}
	var res struct {
		Data            []any   `json:"data"`
		IsGenerating    bool    `json:"is_generating"`
		Duration        float64 `json:"duration"`
		AverageDuration float64 `json:"average_duration"`
	}

	response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
	if err != nil {
		cbReq.Message = "error with send request: " + err.Error()
		return cbReq
	}
	if response.IsErrorState() {
		// Best effort: the body is only used as diagnostic text.
		payload, _ := io.ReadAll(response.Body)
		cbReq.Message = "error http status code: " + string(payload)
		return cbReq
	}

	// The code below reads res.Data[0] (image list) and res.Data[1] (info).
	if len(res.Data) < 2 {
		cbReq.Message = fmt.Sprintf("error with decode response: expected at least 2 data items, got %d", len(res.Data))
		return cbReq
	}

	var images []struct {
		Name   string `json:"name"`
		Data   any    `json:"data"`
		IsFile bool   `json:"is_file"`
	}
	if err = utils.ForceCovert(res.Data[0], &images); err != nil {
		cbReq.Message = "error with decode image:" + err.Error()
		return cbReq
	}
	if len(images) == 0 {
		cbReq.Message = "error with decode image: no image in response"
		return cbReq
	}

	var info map[string]any
	if err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info); err != nil {
		cbReq.Message = err.Error()
		return cbReq
	}

	cbReq.ImageName = images[0].Name
	cbReq.Seed = utils.InterfaceToString(info["seed"])
	cbReq.Success = true
	cbReq.Progress = 100
	return cbReq
}