mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
170 lines
4.4 KiB
Go
170 lines
4.4 KiB
Go
package sd
|
|
|
|
import (
|
|
"chatplus/core/types"
|
|
"chatplus/utils"
|
|
"fmt"
|
|
"github.com/imroc/req/v3"
|
|
"io"
|
|
"time"
|
|
)
|
|
|
|
// Client sends Stable Diffusion jobs to a web UI over HTTP.
type Client struct {
	// httpClient is the req client used for the predict and progress requests.
	httpClient *req.Client
	// config holds the Stable Diffusion settings; config.ApiURL is the base
	// URL every request path is appended to.
	config *types.StableDiffusionConfig
}
|
|
|
|
func NewSdClient(config *types.AppConfig) *Client {
|
|
return &Client{
|
|
config: &config.SdConfig,
|
|
httpClient: req.C(),
|
|
}
|
|
}
|
|
|
|
// Txt2Img queues a text-to-image job on the Stable Diffusion web UI.
//
// It decodes Text2ImgParamTemplate into the positional parameter list and
// writes each field of params into the slot that ParamKeys assigns to it.
// The job is then started in a background goroutine via runTask. A nil
// return therefore means only that the request was built and dispatched;
// the generation outcome is handled by runTask, not reported here.
//
// An error is returned if the template cannot be decoded, if a parameter
// name is missing from ParamKeys, or if its index falls outside the
// template. Previously a missing name silently overwrote slot 0 and an
// out-of-range index panicked.
func (c *Client) Txt2Img(params types.SdTaskParams) error {
	var data []interface{}
	if err := utils.JsonDecode(Text2ImgParamTemplate, &data); err != nil {
		return fmt.Errorf("decoding txt2img param template: %w", err)
	}

	// Kept as an ordered list (not a map) so the writes happen in the same
	// sequence as before, in case two names ever share an index.
	fields := []struct {
		key   string
		value interface{}
	}{
		{"task_id", params.TaskId},
		{"prompt", params.Prompt},
		{"negative_prompt", params.NegativePrompt},
		{"steps", params.Steps},
		{"sampler", params.Sampler},
		{"face_fix", params.FaceFix},
		{"cfg_scale", params.CfgScale},
		{"seed", params.Seed},
		{"height", params.Height},
		{"width", params.Width},
		{"hd_fix", params.HdFix},
		{"hd_redraw_rate", params.HdRedrawRate},
		{"hd_scale", params.HdScale},
		{"hd_scale_alg", params.HdScaleAlg},
		{"hd_sample_num", params.HdSampleNum},
	}
	for _, f := range fields {
		idx, ok := ParamKeys[f.key]
		if !ok {
			return fmt.Errorf("txt2img param %q is not defined in ParamKeys", f.key)
		}
		if idx < 0 || idx >= len(data) {
			return fmt.Errorf("txt2img param %q index %d out of range (template has %d slots)", f.key, idx, len(data))
		}
		data[idx] = f.value
	}

	task := TaskInfo{
		TaskId:    params.TaskId,
		Data:      data,
		EventData: nil,
		// NOTE(review): FnIndex selects the web UI's txt2img handler and the
		// session hash is a fixed placeholder. Both look tied to a specific
		// web UI build; confirm before upgrading the server.
		FnIndex:     494,
		SessionHash: "ycaxgzm9ah",
	}

	go c.runTask(task, c.httpClient)
	return nil
}
|
|
|
|
// runTask submits taskInfo to the web UI's /run/predict endpoint and waits
// for it to finish. While waiting, it polls /internal/progress once per
// second. When the prediction returns, the image URL is logged on success,
// or the failure message is logged otherwise.
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
	// Buffered so the predict goroutine can always deliver its single result
	// and exit, whatever state the receiving loop is in.
	result := make(chan CBReq, 1)
	go func() {
		result <- c.predict(taskInfo, client)
	}()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case value := <-result:
			if value.Success {
				logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
			} else {
				logger.Error(value.Message)
			}
			return
		case <-ticker.C:
			// Progress is informational only: a failed poll is logged and the
			// loop keeps waiting for the authoritative predict result.
			// TODO: consider a retry limit / backoff here.
			if err := c.pollProgress(taskInfo, client); err != nil {
				logger.Error(err)
			}
		}
	}
}

// predict performs the blocking /run/predict call for taskInfo. It always
// returns exactly one CBReq: Success=true with ImageName, Seed and
// Progress=100 filled in, or Success=false with Message describing the
// failure.
func (c *Client) predict(taskInfo TaskInfo, client *req.Client) CBReq {
	failed := func(msg string) CBReq {
		return CBReq{TaskId: taskInfo.TaskId, Success: false, Message: msg}
	}

	body := map[string]any{
		"data":         taskInfo.Data,
		"event_data":   taskInfo.EventData,
		"fn_index":     taskInfo.FnIndex,
		"session_hash": taskInfo.SessionHash,
	}
	var res struct {
		Data            []interface{} `json:"data"`
		IsGenerating    bool          `json:"is_generating"`
		Duration        float64       `json:"duration"`
		AverageDuration float64       `json:"average_duration"`
	}
	response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
	if err != nil {
		return failed("error with send request: " + err.Error())
	}
	if response.IsErrorState() {
		return failed(describeErrorResponse(response))
	}

	// The code below reads Data[0] as the image list and Data[1] as a JSON
	// info string, so both must be present. Without this check an unexpected
	// reply would panic and take down the process.
	if len(res.Data) < 2 {
		return failed(fmt.Sprintf("unexpected predict response: expected at least 2 data items, got %d", len(res.Data)))
	}

	var images []struct {
		Name   string      `json:"name"`
		Data   interface{} `json:"data"`
		IsFile bool        `json:"is_file"`
	}
	if err := utils.ForceCovert(res.Data[0], &images); err != nil {
		return failed("error with decode image:" + err.Error())
	}
	if len(images) == 0 {
		return failed("error with decode image: no image returned")
	}

	var info map[string]any
	if err := utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info); err != nil {
		return failed("error with decode image info: " + err.Error())
	}

	return CBReq{
		TaskId:    taskInfo.TaskId,
		ImageName: images[0].Name,
		Seed:      utils.InterfaceToString(info["seed"]),
		Success:   true,
		Progress:  100,
	}
}

// pollProgress asks /internal/progress how far taskInfo has got and logs the
// result. It returns an error if the request fails or the server replies
// with an error status.
func (c *Client) pollProgress(taskInfo TaskInfo, client *req.Client) error {
	progressReq := map[string]any{
		"id_task":         taskInfo.TaskId,
		"id_live_preview": 1,
	}
	var progressRes struct {
		Active        bool        `json:"active"`
		Queued        bool        `json:"queued"`
		Completed     bool        `json:"completed"`
		Progress      float64     `json:"progress"`
		Eta           float64     `json:"eta"`
		LivePreview   string      `json:"live_preview"`
		IDLivePreview int         `json:"id_live_preview"`
		TextInfo      interface{} `json:"textinfo"`
	}
	response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
	if err != nil {
		return fmt.Errorf("querying progress of task %v: %w", taskInfo.TaskId, err)
	}
	if response.IsErrorState() {
		return fmt.Errorf("querying progress of task %v: %s", taskInfo.TaskId, describeErrorResponse(response))
	}

	// TODO: forward progress and progressRes.LivePreview to the caller. They
	// are currently only logged. The preview is not printed because it can
	// be a large encoded image.
	logger.Debugf("sd task %v progress: %d%%", taskInfo.TaskId, int(progressRes.Progress*100))
	return nil
}

// describeErrorResponse formats a non-2xx response as the status code
// followed by the body the server sent.
func describeErrorResponse(resp *req.Response) string {
	var body string
	if resp.Body != nil {
		if b, err := io.ReadAll(resp.Body); err == nil {
			body = string(b)
		}
	}
	if body == "" {
		// req reads and buffers response bodies by default, in which case
		// Body has already been drained; use the buffered copy instead.
		body = resp.String()
	}
	return fmt.Sprintf("error http status code: %d, %s", resp.StatusCode, body)
}
|