refactor midjourney service, use api key in database

This commit is contained in:
RockYang
2024-08-06 18:30:57 +08:00
parent 72b1515b68
commit 6a8b4ee2f1
29 changed files with 585 additions and 1203 deletions

View File

@@ -11,10 +11,11 @@ import (
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/sd"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"strings"
"time"
@@ -23,127 +24,112 @@ import (
// Service MJ 绘画服务
type Service struct {
Name string // service Name
Client Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
retryCount map[uint]int
client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploaderManager *oss.UploaderManager
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
return &Service{
Name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
running: true,
retryCount: make(map[uint]int),
db: db,
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
client: client,
Clients: types.NewLMap[uint, *types.WsClient](),
uploaderManager: manager,
}
}
const failedProgress = 101
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for s.running {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
if s.retryCount[task.Id] > 5 {
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
logger.Info("Starting MidJourney job consumer for service")
go func() {
for {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
s.retryCount[task.Id]++
time.Sleep(time.Second)
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini")
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.Client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task)
break
case types.TaskVariation:
res, err = s.Client.Variation(task)
break
case types.TaskBlend:
res, err = s.Client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.Client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = failedProgress
job.ErrMsg = errMsg
// update the task progress
// use fast mode as default
if task.Mode == "" {
task.Mode = "fast"
}
var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil {
logger.Error("任务不存在任务ID", task.TaskId)
continue
}
logger.Infof("handle a new MidJourney task: %+v", task)
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.client.Imagine(task)
break
case types.TaskUpscale:
res, err = s.client.Upscale(task)
break
case types.TaskVariation:
res, err = s.client.Variation(task)
break
case types.TaskBlend:
res, err = s.client.Blend(task)
break
case types.TaskSwapFace:
res, err = s.client.SwapFace(task)
break
}
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg
// update the task progress
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = res.Channel
s.db.Updates(&job)
// 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
continue
}
logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道
job.TaskId = res.Result
job.MessageId = res.Result
job.ChannelId = s.Name
s.db.Updates(&job)
}
}
func (s *Service) Stop() {
s.running = false
}()
}
type CBReq struct {
@@ -164,46 +150,6 @@ type CBReq struct {
} `json:"properties"`
}
func (s *Service) Notify(job model.MidJourneyJob) error {
task, err := s.Client.QueryTask(job.TaskId)
if err != nil {
return err
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": failedProgress,
"err_msg": task.FailReason,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
return fmt.Errorf("task failed: %v", task.FailReason)
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
tx := s.db.Updates(&job)
if tx.Error != nil {
return fmt.Errorf("error with update database: %v", tx.Error)
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := sd.Running
if job.Progress == 100 {
message = sd.Finished
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
return nil
}
func GetImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
@@ -211,3 +157,143 @@ func GetImageHash(action string) string {
}
return split[len(split)-1]
}
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := s.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
s.db.Updates(&v)
cli := s.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(service.TaskStatusFinished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
}
// 任务执行失败了
if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": service.FailTaskProgress,
"err_msg": task.FailReason,
})
logger.Errorf("task failed: %v", task.FailReason)
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
}
if len(task.Buttons) > 0 {
job.Hash = GetImageHash(task.Buttons[0].CustomId)
}
oldProgress := job.Progress
job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
job.Prompt = task.PromptEn
if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl
}
err = s.db.Updates(&job).Error
if err != nil {
logger.Errorf("error with update database: %v", err)
continue
}
// 通知前端更新任务进度
if oldProgress != job.Progress {
message := service.TaskStatusRunning
if job.Progress == 100 {
message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
}
}
time.Sleep(time.Second * 5)
}
}()
}