mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
214 lines
5.7 KiB
Go
214 lines
5.7 KiB
Go
package mj
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"fmt"
|
||
"geekai/core/types"
|
||
"geekai/service"
|
||
"geekai/service/sd"
|
||
"geekai/store"
|
||
"geekai/store/model"
|
||
"geekai/utils"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Service MJ 绘画服务
|
||
type Service struct {
|
||
Name string // service Name
|
||
Client Client // MJ Client
|
||
taskQueue *store.RedisQueue
|
||
notifyQueue *store.RedisQueue
|
||
db *gorm.DB
|
||
running bool
|
||
retryCount map[uint]int
|
||
}
|
||
|
||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
|
||
return &Service{
|
||
Name: name,
|
||
db: db,
|
||
taskQueue: taskQueue,
|
||
notifyQueue: notifyQueue,
|
||
Client: cli,
|
||
running: true,
|
||
retryCount: make(map[uint]int),
|
||
}
|
||
}
|
||
|
||
// failedProgress is the progress value written to a job to mark it as failed.
const failedProgress = 101
|
||
|
||
func (s *Service) Run() {
|
||
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
|
||
for s.running {
|
||
var task types.MjTask
|
||
err := s.taskQueue.LPop(&task)
|
||
if err != nil {
|
||
logger.Errorf("taking task with error: %v", err)
|
||
continue
|
||
}
|
||
|
||
// 如果配置了多个中转平台的 API KEY
|
||
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
|
||
if task.ChannelId != "" && task.ChannelId != s.Name {
|
||
if s.retryCount[task.Id] > 5 {
|
||
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
|
||
continue
|
||
}
|
||
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
|
||
s.taskQueue.RPush(task)
|
||
s.retryCount[task.Id]++
|
||
time.Sleep(time.Second)
|
||
continue
|
||
}
|
||
|
||
// translate prompt
|
||
if utils.HasChinese(task.Prompt) {
|
||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
|
||
if err == nil {
|
||
task.Prompt = content
|
||
} else {
|
||
logger.Warnf("error with translate prompt: %v", err)
|
||
}
|
||
}
|
||
// translate negative prompt
|
||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
|
||
if err == nil {
|
||
task.NegPrompt = content
|
||
} else {
|
||
logger.Warnf("error with translate prompt: %v", err)
|
||
}
|
||
}
|
||
|
||
var job model.MidJourneyJob
|
||
tx := s.db.Where("id = ?", task.Id).First(&job)
|
||
if tx.Error != nil {
|
||
logger.Error("任务不存在,任务ID:", task.TaskId)
|
||
continue
|
||
}
|
||
|
||
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
|
||
var res ImageRes
|
||
switch task.Type {
|
||
case types.TaskImage:
|
||
res, err = s.Client.Imagine(task)
|
||
break
|
||
case types.TaskUpscale:
|
||
res, err = s.Client.Upscale(task)
|
||
break
|
||
case types.TaskVariation:
|
||
res, err = s.Client.Variation(task)
|
||
break
|
||
case types.TaskBlend:
|
||
res, err = s.Client.Blend(task)
|
||
break
|
||
case types.TaskSwapFace:
|
||
res, err = s.Client.SwapFace(task)
|
||
break
|
||
}
|
||
|
||
if err != nil || (res.Code != 1 && res.Code != 22) {
|
||
var errMsg string
|
||
if err != nil {
|
||
errMsg = err.Error()
|
||
} else {
|
||
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
|
||
}
|
||
|
||
logger.Error("绘画任务执行失败:", errMsg)
|
||
job.Progress = failedProgress
|
||
job.ErrMsg = errMsg
|
||
// update the task progress
|
||
s.db.Updates(&job)
|
||
// 任务失败,通知前端
|
||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||
continue
|
||
}
|
||
logger.Infof("任务提交成功:%+v", res)
|
||
// 更新任务 ID/频道
|
||
job.TaskId = res.Result
|
||
job.MessageId = res.Result
|
||
job.ChannelId = s.Name
|
||
s.db.Updates(&job)
|
||
}
|
||
}
|
||
|
||
func (s *Service) Stop() {
|
||
s.running = false
|
||
}
|
||
|
||
// CBReq maps the JSON description of an MJ task onto Go fields.
// NOTE(review): the name suggests it is the body of a callback request sent
// by the MJ relay - confirm against the code that decodes it.
type CBReq struct {
	// Identification and state of the task.
	Id     string `json:"id"`
	Action string `json:"action"`
	Status string `json:"status"`

	// Prompt as submitted, its "promptEn" counterpart and a description.
	Prompt      string `json:"prompt"`
	PromptEn    string `json:"promptEn"`
	Description string `json:"description"`

	// Timestamps reported for the task (unit not visible here - confirm).
	SubmitTime int64 `json:"submitTime"`
	StartTime  int64 `json:"startTime"`
	FinishTime int64 `json:"finishTime"`

	// Progress is delivered as a string, ImageUrl as the image location.
	Progress string `json:"progress"`
	ImageUrl string `json:"imageUrl"`

	// FailReason is left untyped because its JSON type is not fixed here.
	FailReason interface{} `json:"failReason"`

	// Properties carries the nested "properties" object.
	Properties struct {
		FinalPrompt string `json:"finalPrompt"`
	} `json:"properties"`
}
|
||
|
||
func (s *Service) Notify(job model.MidJourneyJob) error {
|
||
task, err := s.Client.QueryTask(job.TaskId)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 任务执行失败了
|
||
if task.FailReason != "" {
|
||
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
|
||
"progress": failedProgress,
|
||
"err_msg": task.FailReason,
|
||
})
|
||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
|
||
return fmt.Errorf("task failed: %v", task.FailReason)
|
||
}
|
||
|
||
if len(task.Buttons) > 0 {
|
||
job.Hash = GetImageHash(task.Buttons[0].CustomId)
|
||
}
|
||
oldProgress := job.Progress
|
||
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
|
||
job.Prompt = task.PromptEn
|
||
if task.ImageUrl != "" {
|
||
job.OrgURL = task.ImageUrl
|
||
}
|
||
tx := s.db.Updates(&job)
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("error with update database: %v", tx.Error)
|
||
}
|
||
// 通知前端更新任务进度
|
||
if oldProgress != job.Progress {
|
||
message := sd.Running
|
||
if job.Progress == 100 {
|
||
message = sd.Finished
|
||
}
|
||
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetImageHash extracts the image hash from a "::"-separated action string.
//
// It returns the fifth segment when the string has at least five segments and
// the last segment otherwise; a string without "::" is returned unchanged.
func GetImageHash(action string) string {
	parts := strings.Split(action, "::")
	// Clamp the index to the fifth segment; shorter inputs use their last one.
	idx := len(parts) - 1
	if idx > 4 {
		idx = 4
	}
	return parts[idx]
}
|