feat: support midjourney --cref and --sref for role consistency

This commit is contained in:
RockYang
2024-04-02 14:59:53 +08:00
parent 56c225bf20
commit 3f1ad4b7dc
10 changed files with 247 additions and 53 deletions

View File

@@ -25,6 +25,7 @@ type MjTask struct {
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
Params string `json:"full_prompt"`
Index int `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`

View File

@@ -98,7 +98,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
ImgArr []string `json:"img_arr"`
Tile bool `json:"tile"`
Quality float32 `json:"quality"`
Weight float32 `json:"weight"`
Iw float32 `json:"iw"`
CRef string `json:"cref"` //生成角色一致的图像
SRef string `json:"sref"` //生成风格一致的图像
Cw int `json:"cw"` // 参考程度
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
@@ -108,41 +111,53 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
var prompt = data.Prompt
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
prompt += " --ar " + data.Rate
var params = ""
if data.Rate != "" && !strings.Contains(params, "--ar") {
params += " --ar " + data.Rate
}
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
prompt += fmt.Sprintf(" --seed %d", data.Seed)
if data.Seed > 0 && !strings.Contains(params, "--seed") {
params += fmt.Sprintf(" --seed %d", data.Seed)
}
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
prompt += fmt.Sprintf(" --s %d", data.Stylize)
if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
params += fmt.Sprintf(" --s %d", data.Stylize)
}
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
prompt += fmt.Sprintf(" --c %d", data.Chaos)
if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
params += fmt.Sprintf(" --c %d", data.Chaos)
}
if data.Weight > 0 {
prompt += fmt.Sprintf(" --iw %f", data.Weight)
if len(data.ImgArr) > 0 && data.Iw > 0 {
params += fmt.Sprintf(" --iw %f", data.Iw)
}
if data.Raw {
prompt += " --style raw"
params += " --style raw"
}
if data.Quality > 0 {
prompt += fmt.Sprintf(" --q %.2f", data.Quality)
params += fmt.Sprintf(" --q %.2f", data.Quality)
}
if data.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", data.NegPrompt)
params += fmt.Sprintf(" --no %s", data.NegPrompt)
}
if data.Tile {
prompt += " --tile "
params += " --tile "
}
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
prompt += fmt.Sprintf(" %s", data.Model)
if data.CRef != "" {
params += fmt.Sprintf(" --cref %s", data.CRef)
if data.Cw > 0 {
params += fmt.Sprintf(" --cw %d", data.Cw)
} else {
params += " --cw 100"
}
}
if data.SRef != "" {
params += fmt.Sprintf(" --sref %s", data.CRef)
}
if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
params += fmt.Sprintf(" %s", data.Model)
}
// 处理融图和换脸的提示词
if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
prompt = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
}
// 如果本地图片上传的是相对地址,处理成绝对地址
@@ -165,7 +180,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
UserId: userId,
TaskId: taskId,
Progress: 0,
Prompt: prompt,
Prompt: fmt.Sprintf("%s %s", data.Prompt, params),
Power: h.App.SysConfig.MjPower,
CreatedAt: time.Now(),
}
@@ -188,7 +203,8 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
TaskId: taskId,
SessionId: data.SessionId,
Type: types.TaskType(data.TaskType),
Prompt: prompt,
Prompt: data.Prompt,
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
})

View File

@@ -53,6 +53,10 @@ func (l *AppLifecycle) OnStop(context.Context) error {
return nil
}
func NewAppLifeCycle() *AppLifecycle {
return &AppLifecycle{}
}
func main() {
configFile := os.Getenv("CONFIG_FILE")
if configFile == "" {
@@ -432,11 +436,14 @@ func main() {
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
go func() {
err := s.Run(db)
if err != nil {
log.Fatal(err)
}
}()
}),
fx.Provide(NewAppLifeCycle),
// 注册生命周期回调函数
fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
lifecycle.Append(fx.Hook{

View File

@@ -26,7 +26,7 @@ func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: task.Prompt,
Prompt: fmt.Sprintf("%s %s", task.Prompt, task.Params),
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码

View File

@@ -23,7 +23,7 @@ func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
body := ImageReq{
Prompt: task.Prompt,
Prompt: fmt.Sprintf("%s %s", task.Prompt, task.Params),
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
@@ -46,8 +46,6 @@ func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
all, err := io.ReadAll(r.Body)
logger.Info(string(all))
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}

View File

@@ -68,7 +68,7 @@ func (s *Service) Run() {
}
// 如果是 mj-proxy 则自动翻译提示词
if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") {
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt))
if err == nil {
task.Prompt = content