feat: stable diffusion image drawing on mobile is ready

This commit is contained in:
RockYang
2024-04-03 18:13:48 +08:00
parent 9b7ee538c4
commit 0355c37bef
19 changed files with 1457 additions and 545 deletions

View File

@@ -25,6 +25,7 @@ type MjTask struct {
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
NegPrompt string `json:"neg_prompt,omitempty"`
Params string `json:"full_prompt"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
@@ -42,19 +43,19 @@ type SdTask struct {
}
type SdTaskParams struct {
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegativePrompt string `json:"negative_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
TaskId string `json:"task_id"`
Prompt string `json:"prompt"` // 提示词
NegPrompt string `json:"neg_prompt"` // 反向提示词
Steps int `json:"steps"` // 迭代步数默认20
Sampler string `json:"sampler"` // 采样器
FaceFix bool `json:"face_fix"` // 面部修复
CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7
Seed int64 `json:"seed"` // 随机数种子
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // 启用高清修复
HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
HdScale int `json:"hd_scale"` // 放大倍数
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}

View File

@@ -133,9 +133,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
if data.Quality > 0 {
params += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.NegPrompt != "" {
params += fmt.Sprintf(" --no %s", data.NegPrompt)
}
if data.Tile {
params += " --tile "
}
@@ -204,6 +201,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
SessionId: data.SessionId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,

View File

@@ -424,27 +424,21 @@ func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
var opt string
var power int
if user.Vip { // 已经是 VIP 用户
if remark.Days > 0 { // 只延期 VIP不增加调用次数
if remark.Days > 0 { // VIP 充值
if user.ExpiredTime >= time.Now().Unix() {
user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
} else { // 非 VIP 用户
if remark.Days > 0 { // vip 套餐days > 0, power == 0
opt = "VIP充值VIP 没到期,只延期不增加算力"
} else {
user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
user.Power += h.App.SysConfig.VipMonthPower
user.Vip = true
opt = "VIP充值"
power = h.App.SysConfig.VipMonthPower
} else { //点卡days == 0, calls > 0
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
opt = "VIP充值"
}
user.Vip = true
} else { // 充值点卡,直接增加次数即可
user.Power += remark.Power
opt = "点卡充值"
power = remark.Power
}
// 更新用户信息

View File

@@ -127,21 +127,21 @@ func (h *SdJobHandler) Image(c *gin.Context) {
return
}
params := types.SdTaskParams{
TaskId: taskId,
Prompt: data.Prompt,
NegativePrompt: data.NegativePrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
TaskId: taskId,
Prompt: data.Prompt,
NegPrompt: data.NegPrompt,
Steps: data.Steps,
Sampler: data.Sampler,
FaceFix: data.FaceFix,
CfgScale: data.CfgScale,
Seed: data.Seed,
Height: data.Height,
Width: data.Width,
HdFix: data.HdFix,
HdRedrawRate: data.HdRedrawRate,
HdScale: data.HdScale,
HdScaleAlg: data.HdScaleAlg,
HdSteps: data.HdSteps,
}
job := model.SdJob{

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
@@ -16,17 +17,26 @@ import (
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
}
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
return &PlusClient{Config: config, apiURL: config.ApiURL}
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
}
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: fmt.Sprintf("%s %s", task.Prompt, task.Params),
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
@@ -42,7 +52,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
@@ -81,7 +91,7 @@ func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
@@ -130,7 +140,7 @@ func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.SetTimeout(time.Minute).R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
@@ -156,7 +166,7 @@ func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
r, err := c.client.R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
@@ -202,7 +212,7 @@ func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)

View File

@@ -22,8 +22,12 @@ func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: fmt.Sprintf("%s %s", task.Prompt, task.Params),
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码

View File

@@ -67,7 +67,7 @@ func (s *Service) Run() {
continue
}
// 如果是 mj-proxy 则自动翻译提示词
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
@@ -76,6 +76,15 @@ func (s *Service) Run() {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt))
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)

View File

@@ -49,11 +49,24 @@ func (s *Service) Run() {
logger.Errorf("taking task with error: %v", err)
continue
}
// 翻译提示词
// translate prompt
if utils.HasChinese(task.Params.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
if err == nil {
task.Params.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
if err == nil {
task.Params.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
@@ -110,7 +123,7 @@ type TaskProgressResp struct {
func (s *Service) Txt2Img(task types.SdTask) error {
body := Txt2ImgReq{
Prompt: task.Params.Prompt,
NegativePrompt: task.Params.NegativePrompt,
NegativePrompt: task.Params.NegPrompt,
Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,