diff --git a/CHANGELOG.md b/CHANGELOG.md index a1eddfdf..9f898a32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * 功能优化:优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力 * Bug修复:修复后台拖动排序组件 Bug * 功能优化:更新数据库失败时候显示具体的的报错信息 +* Bug修复:修复管理后台对话详情页内容显示异常问题 ## v4.1.1 * Bug修复:修复 GPT 模型 function call 调用后没有输出的问题 diff --git a/api/config.sample.toml b/api/config.sample.toml index eb33a5d7..a52f1023 100644 --- a/api/config.sample.toml +++ b/api/config.sample.toml @@ -64,18 +64,7 @@ TikaHost = "http://tika:9998" Bucket = "chatgpt-plus" SubDir = "" Domain = "" - -[[MjProxyConfigs]] - Enabled = true - ApiURL = "http://midjourney-proxy:8082" - ApiKey = "sk-geekmaster" - -[[MjPlusConfigs]] - Enabled = false - ApiURL = "https://api.chat-plus.net" - Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo - ApiKey = "sk-xxx" - + [[SdConfigs]] Enabled = false ApiURL = "" diff --git a/api/core/types/config.go b/api/core/types/config.go index 67845571..984a6982 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -12,22 +12,20 @@ import ( ) type AppConfig struct { - Path string `toml:"-"` - Listen string - Session Session - AdminSession Session - ProxyURL string - MysqlDns string // mysql 连接地址 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ApiConfig // ChatPlus API authorization configs - SMS SMSConfig // send mobile message config - OSS OSSConfig // OSS config - MjProxyConfigs []MjProxyConfig // MJ proxy config - MjPlusConfigs []MjPlusConfig // MJ plus config - WeChatBot bool // 是否启用微信机器人 - SdConfigs []StableDiffusionConfig // sd AI draw service pool + Path string `toml:"-"` + Listen string + Session Session + AdminSession Session + ProxyURL string + MysqlDns string // mysql 连接地址 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ApiConfig // ChatPlus API authorization configs + SMS SMSConfig // send mobile message config + OSS OSSConfig // OSS config + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool XXLConfig XXLConfig AlipayConfig AlipayConfig // 支付宝支付渠道配置 @@ -188,6 +186,7 @@ type SystemConfig struct { ContextDeep int `json:"context_deep,omitempty"` SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 + MjMode string `json:"mj_mode"` // midjourney 默认的API模式,relax, fast, turbo IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片 IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单 diff --git a/api/core/types/task.go b/api/core/types/task.go index 0fb451b7..72b0d7c6 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -27,7 +27,6 @@ type MjTask struct { Id uint `json:"id"` TaskId string `json:"task_id"` ImgArr []string `json:"img_arr"` - ChannelId string `json:"channel_id"` Type TaskType `json:"type"` UserId int `json:"user_id"` Prompt string `json:"prompt,omitempty"` @@ -37,6 +36,8 @@ type MjTask struct { MessageId string `json:"message_id,omitempty"` MessageHash string `json:"message_hash,omitempty"` RetryCount int `json:"retry_count"` + ChannelId string `json:"channel_id"` // 渠道ID,用来区分是哪个渠道创建的任务,一个任务的 create 和 action 操作必须要再同一个渠道 + Mode string `json:"mode"` // 绘画模式,relax, fast, turbo } type SdTask struct { diff --git a/api/handler/admin/chat_role_handler.go b/api/handler/admin/chat_role_handler.go index 74b3d398..3776ddc9 100644 --- a/api/handler/admin/chat_role_handler.go +++ b/api/handler/admin/chat_role_handler.go @@ -8,6 +8,7 @@ package admin // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "fmt" "geekai/core" "geekai/core/types" "geekai/handler" @@ -45,6 +46,12 @@ func (h *ChatRoleHandler) Save(c *gin.Context) { role.Id = data.Id if data.CreatedAt > 0 { role.CreatedAt = time.Unix(data.CreatedAt, 0) + } else { + err = h.DB.Where("marker", data.Key).First(&role).Error + if err == nil { + resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key)) + return + } } err = h.DB.Save(&role).Error if err != nil { diff --git a/api/handler/admin/config_handler.go b/api/handler/admin/config_handler.go index c50c8d6f..b3d22705 100644 --- a/api/handler/admin/config_handler.go +++ b/api/handler/admin/config_handler.go @@ -12,7 +12,6 @@ import ( "geekai/core/types" "geekai/handler" "geekai/service" - "geekai/service/mj" "geekai/service/sd" "geekai/store" "geekai/store/model" @@ -28,15 +27,13 @@ type ConfigHandler struct { handler.BaseHandler levelDB *store.LevelDB licenseService *service.LicenseService - mjServicePool *mj.ServicePool sdServicePool *sd.ServicePool } -func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { +func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler { return &ConfigHandler{ BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB, - mjServicePool: mjPool, sdServicePool: sdPool, licenseService: licenseService, } @@ -146,58 +143,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) { license := h.licenseService.GetLicense() resp.SUCCESS(c, license) } - -// GetAppConfig 获取内置配置 -func (h *ConfigHandler) GetAppConfig(c *gin.Context) { - resp.SUCCESS(c, gin.H{ - "mj_plus": h.App.Config.MjPlusConfigs, - "mj_proxy": h.App.Config.MjProxyConfigs, - "sd": h.App.Config.SdConfigs, - }) -} - -// SaveDrawingConfig 保存AI绘画配置 -func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) { - var data struct { - Sd []types.StableDiffusionConfig `json:"sd"` - MjPlus []types.MjPlusConfig `json:"mj_plus"` - MjProxy []types.MjProxyConfig `json:"mj_proxy"` - } - if err := c.ShouldBindJSON(&data); err != nil { - resp.ERROR(c, types.InvalidArgs) - return - } - - changed := false - if configChanged(data.Sd, h.App.Config.SdConfigs) { - logger.Debugf("SD 配置变动了") - h.App.Config.SdConfigs = data.Sd - h.sdServicePool.InitServices(data.Sd) - changed = true - } - - if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) { - logger.Debugf("MidJourney 配置变动了") - h.App.Config.MjPlusConfigs = data.MjPlus - h.App.Config.MjProxyConfigs = data.MjProxy - h.mjServicePool.InitServices(data.MjPlus, data.MjProxy) - changed = true - } - - if changed { - err := core.SaveConfig(h.App.Config) - if err != nil { - resp.ERROR(c, "更新配置文档失败!") - return - } - } - - resp.SUCCESS(c) - -} - -func configChanged(c1 interface{}, c2 interface{}) bool { - encode1 := utils.JsonEncode(c1) - encode2 := utils.JsonEncode(c2) - return utils.Md5(encode1) != utils.Md5(encode2) -} diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 079a9a17..c3cd14d0 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -30,15 +30,15 @@ import ( type MidJourneyHandler struct { BaseHandler - pool *mj.ServicePool + service *mj.Service snowflake *service.Snowflake uploader *oss.UploaderManager } -func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { +func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler { return &MidJourneyHandler{ snowflake: snowflake, - pool: pool, + service: service, uploader: manager, BaseHandler: BaseHandler{ App: app, @@ -59,11 +59,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { return false } - if !h.pool.HasAvailableService() { - resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!") - return false - } - return true } @@ -85,7 +80,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { } client := types.NewWsClient(ws) - h.pool.Clients.Put(uint(userId), client) + h.service.Clients.Put(uint(userId), client) logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) } @@ -201,7 +196,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { return } - h.pool.PushTask(types.MjTask{ + h.service.PushTask(types.MjTask{ Id: job.Id, TaskId: taskId, Type: types.TaskType(data.TaskType), @@ -210,9 +205,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { Params: params, UserId: userId, ImgArr: data.ImgArr, + Mode: h.App.SysConfig.MjMode, }) - client := h.pool.Clients.Get(uint(job.UserId)) + client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -273,7 +269,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { return } - h.pool.PushTask(types.MjTask{ + h.service.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskUpscale, UserId: userId, @@ -281,9 +277,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, + Mode: h.App.SysConfig.MjMode, }) - client := h.pool.Clients.Get(uint(job.UserId)) + client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -337,7 +334,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - h.pool.PushTask(types.MjTask{ + h.service.PushTask(types.MjTask{ Id: job.Id, Type: types.TaskVariation, UserId: userId, @@ -345,9 +342,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { ChannelId: data.ChannelId, MessageId: data.MessageId, MessageHash: data.MessageHash, + Mode: h.App.SysConfig.MjMode, }) - client := h.pool.Clients.Get(uint(job.UserId)) + client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } @@ -500,7 +498,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { logger.Error("remove image failed: ", err) } - client := h.pool.Clients.Get(uint(job.UserId)) + client := h.service.Clients.Get(uint(job.UserId)) if client != nil { _ = client.Send([]byte("Task Updated")) } diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 73381d39..ff5320ef 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -330,7 +330,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) { client := h.pool.Clients.Get(uint(job.UserId)) if client != nil { - _ = client.Send([]byte(sd.Finished)) + _ = client.Send([]byte(service.TaskStatusFinished)) } resp.SUCCESS(c) diff --git a/api/main.go b/api/main.go index d8d2cdba..287c6c07 100644 --- a/api/main.go +++ b/api/main.go @@ -161,13 +161,12 @@ func main() { return service.NewCaptchaService(config.ApiConfig) }), fx.Provide(oss.NewUploaderManager), - fx.Provide(mj.NewService), fx.Provide(dalle.NewService), - fx.Invoke(func(service *dalle.Service) { - service.Run() - service.CheckTaskNotify() - service.DownloadImages() - service.CheckTaskStatus() + fx.Invoke(func(s *dalle.Service) { + s.Run() + s.CheckTaskNotify() + s.DownloadImages() + s.CheckTaskStatus() }), // 邮件服务 @@ -190,14 +189,13 @@ func main() { }), // MidJourney service pool - fx.Provide(mj.NewServicePool), - fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { - pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) - if pool.HasAvailableService() { - pool.DownloadImages() - pool.CheckTaskNotify() - pool.SyncTaskProgress() - } + fx.Provide(mj.NewService), + fx.Provide(mj.NewClient), + fx.Invoke(func(s *mj.Service) { + s.Run() + s.SyncTaskProgress() + s.CheckTaskNotify() + s.DownloadImages() }), // Stable Diffusion 机器人 @@ -317,8 +315,6 @@ func main() { group.GET("config/get", h.Get) group.POST("active", h.Active) group.GET("config/get/license", h.GetLicense) - group.GET("config/get/app", h.GetAppConfig) - group.POST("config/update/draw", h.SaveDrawingConfig) }), fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { group := s.Engine.Group("/api/admin/") diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 5c5a1929..5b1d4ab8 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -14,7 +14,6 @@ import ( logger2 "geekai/logger" "geekai/service" "geekai/service/oss" - "geekai/service/sd" "geekai/store" "geekai/store/model" "geekai/utils" @@ -70,10 +69,10 @@ func (s *Service) Run() { if err != nil { logger.Errorf("error with image task: %v", err) s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ - "progress": 101, + "progress": service.FailTaskProgress, "err_msg": err.Error(), }) - s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) } } }() @@ -191,7 +190,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { return "", fmt.Errorf("err with update database: %v", tx.Error) } - s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed}) var content string if sync { imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) @@ -208,7 +207,7 @@ func (s *Service) CheckTaskNotify() { go func() { logger.Info("Running DALL-E task notify checking ...") for { - var message sd.NotifyMessage + var message service.NotifyMessage err := s.notifyQueue.LPop(&message) if err != nil { continue @@ -239,7 +238,7 @@ func (s *Service) CheckTaskStatus() { for _, job := range jobs { // 超时的任务标记为失败 if time.Now().Sub(job.CreatedAt) > time.Minute*10 { - job.Progress = 101 + job.Progress = service.FailTaskProgress job.ErrMsg = "任务超时" s.db.Updates(&job) } @@ -292,6 +291,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, if res.Error != nil { return "", err } - s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished}) return imgURL, nil } diff --git a/api/service/mj/client.go b/api/service/mj/client.go index 504553f0..450b7d8b 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -7,15 +7,28 @@ package mj // * @Author yangjian102621@163.com // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -import "geekai/core/types" +import ( + "encoding/base64" + "errors" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/store/model" + "geekai/utils" + "github.com/imroc/req/v3" + "gorm.io/gorm" + "io" + "time" -type Client interface { - Imagine(task types.MjTask) (ImageRes, error) - Blend(task types.MjTask) (ImageRes, error) - SwapFace(task types.MjTask) (ImageRes, error) - Upscale(task types.MjTask) (ImageRes, error) - Variation(task types.MjTask) (ImageRes, error) - QueryTask(taskId string) (QueryRes, error) + "github.com/gin-gonic/gin" +) + +// Client MidJourney client +type Client struct { + client *req.Client + licenseService *service.LicenseService + db *gorm.DB } type ImageReq struct { @@ -33,7 +46,8 @@ type ImageRes struct { Description string `json:"description"` Properties struct { } `json:"properties"` - Result string `json:"result"` + Result string `json:"result"` + Channel string `json:"channel,omitempty"` } type ErrRes struct { @@ -66,3 +80,184 @@ type QueryRes struct { Status string `json:"status"` SubmitTime int `json:"submitTime"` } + +var logger = logger2.GetLogger() + +func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client { + return &Client{ + client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), + licenseService: licenseService, + db: db, + } +} + +func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { + apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode) + prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) + if task.NegPrompt != "" { + prompt += fmt.Sprintf(" --no %s", task.NegPrompt) + } + body := ImageReq{ + BotType: "MID_JOURNEY", + Prompt: prompt, + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + + } + return c.doRequest(body, apiPath, task.ChannelId) +} + +// Blend 融图 +func (c *Client) Blend(task types.MjTask) (ImageRes, error) { + apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode) + body := ImageReq{ + BotType: "MID_JOURNEY", + Dimensions: "SQUARE", + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + for _, imgURL := range task.ImgArr { + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + } + } + return c.doRequest(body, apiPath, task.ChannelId) +} + +// SwapFace 换脸 +func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { + apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode) + // 生成图片 Base64 编码 + if len(task.ImgArr) != 2 { + return ImageRes{}, errors.New("参数错误,必须上传2张图片") + } + var sourceBase64 string + var targetBase64 string + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + imageData, err = utils.DownloadImage(task.ImgArr[1], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + + body := gin.H{ + "sourceBase64": sourceBase64, + "targetBase64": targetBase64, + "accountFilter": gin.H{ + "instanceId": "", + }, + "state": "", + } + return c.doRequest(body, apiPath, task.ChannelId) +} + +// Upscale 放大指定的图片 +func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, + } + apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode) + return c.doRequest(body, apiPath, task.ChannelId) +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *Client) Variation(task types.MjTask) (ImageRes, error) { + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, + } + apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode) + + return c.doRequest(body, apiPath, task.ChannelId) +} + +func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) { + var res ImageRes + var errRes ErrRes + session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true) + if channel != "" { + session = session.Where("api_url", channel) + } + + var apiKey model.ApiKey + err := session.Order("last_used_at ASC").First(&apiKey).Error + if err != nil { + return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err) + } + + if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil { + return ImageRes{}, err + } + + apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath) + logger.Info("API URL: ", apiURL) + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + errMsg := err.Error() + if r != nil { + errStr, _ := io.ReadAll(r.Body) + logger.Error("请求 API 出错:", string(errStr)) + errMsg = errMsg + " " + string(errStr) + } + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + // update the api key last used time + if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil { + logger.Error("update api key last used time error: ", err) + } + res.Channel = apiKey.ApiURL + return res, nil +} + +func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) { + var apiKey model.ApiKey + err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error + if err != nil { + return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err) + } + apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId) + var res QueryRes + r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + return QueryRes{}, err + } + + if r.IsErrorState() { + return QueryRes{}, errors.New("error status:" + r.Status) + } + + return res, nil +} diff --git a/api/service/mj/plus_client.go b/api/service/mj/plus_client.go deleted file mode 100644 index 736c52b2..00000000 --- a/api/service/mj/plus_client.go +++ /dev/null @@ -1,211 +0,0 @@ -package mj - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "encoding/base64" - "errors" - "fmt" - "geekai/core/types" - "geekai/service" - "geekai/utils" - "github.com/imroc/req/v3" - "io" - "time" - - "github.com/gin-gonic/gin" -) - -// PlusClient MidJourney Plus ProxyClient -type PlusClient struct { - Config types.MjPlusConfig - apiURL string - client *req.Client - licenseService *service.LicenseService -} - -func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient { - return &PlusClient{ - Config: config, - apiURL: config.ApiURL, - client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), - licenseService: licenseService, - } -} - -func (c *PlusClient) preCheck() error { - return c.licenseService.IsValidApiURL(c.Config.ApiURL) -} - -func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { - if err := c.preCheck(); err != nil { - return ImageRes{}, err - } - - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) - prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) - if task.NegPrompt != "" { - prompt += fmt.Sprintf(" --no %s", task.NegPrompt) - } - body := ImageReq{ - BotType: "MID_JOURNEY", - Prompt: prompt, - Base64Array: make([]string, 0), - } - // 生成图片 Base64 编码 - if len(task.ImgArr) > 0 { - imageData, err := utils.DownloadImage(task.ImgArr[0], "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) - } - - } - return c.doRequest(body, apiURL) -} - -// Blend 融图 -func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { - if err := c.preCheck(); err != nil { - return ImageRes{}, err - } - - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) - logger.Info("API URL: ", apiURL) - body := ImageReq{ - BotType: "MID_JOURNEY", - Dimensions: "SQUARE", - Base64Array: make([]string, 0), - } - // 生成图片 Base64 编码 - if len(task.ImgArr) > 0 { - for _, imgURL := range task.ImgArr { - imageData, err := utils.DownloadImage(imgURL, "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) - } - } - } - return c.doRequest(body, apiURL) -} - -// SwapFace 换脸 -func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { - if err := c.preCheck(); err != nil { - return ImageRes{}, err - } - - apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) - // 生成图片 Base64 编码 - if len(task.ImgArr) != 2 { - return ImageRes{}, errors.New("参数错误,必须上传2张图片") - } - var sourceBase64 string - var targetBase64 string - imageData, err := utils.DownloadImage(task.ImgArr[0], "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) - } - imageData, err = utils.DownloadImage(task.ImgArr[1], "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) - } - - body := gin.H{ - "sourceBase64": sourceBase64, - "targetBase64": targetBase64, - "accountFilter": gin.H{ - "instanceId": "", - }, - "state": "", - } - return c.doRequest(body, apiURL) -} - -// Upscale 放大指定的图片 -func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { - if err := c.preCheck(); err != nil { - return ImageRes{}, err - } - - body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), - "taskId": task.MessageId, - } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) - return c.doRequest(body, apiURL) -} - -// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { - if err := c.preCheck(); err != nil { - return ImageRes{}, err - } - - body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), - "taskId": task.MessageId, - } - apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) - - return c.doRequest(body, apiURL) -} - -func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) { - var res ImageRes - var errRes ErrRes - logger.Info("API URL: ", apiURL) - r, err := req.C().R(). - SetHeader("Authorization", "Bearer "+c.Config.ApiKey). - SetBody(body). - SetSuccessResult(&res). - SetErrorResult(&errRes). - Post(apiURL) - if err != nil { - errMsg := err.Error() - if r != nil { - errStr, _ := io.ReadAll(r.Body) - logger.Error("请求 API 出错:", string(errStr)) - errMsg = errMsg + " " + string(errStr) - } - return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg) - } - - if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) - } - - return res, nil -} - -func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) { - apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) - var res QueryRes - r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). - SetSuccessResult(&res). - Get(apiURL) - - if err != nil { - return QueryRes{}, err - } - - if r.IsErrorState() { - return QueryRes{}, errors.New("error status:" + r.Status) - } - - return res, nil -} - -var _ Client = &PlusClient{} diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go deleted file mode 100644 index 0e319fe6..00000000 --- a/api/service/mj/pool.go +++ /dev/null @@ -1,215 +0,0 @@ -package mj - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "geekai/core/types" - logger2 "geekai/logger" - "geekai/service" - "geekai/service/oss" - "geekai/service/sd" - "geekai/store" - "geekai/store/model" - "geekai/utils" - "github.com/go-redis/redis/v8" - "strings" - "time" - - "gorm.io/gorm" -) - -// ServicePool Mj service pool -type ServicePool struct { - services []*Service - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - uploaderManager *oss.UploaderManager - Clients *types.LMap[uint, *types.WsClient] // UserId => Client - licenseService *service.LicenseService -} - -var logger = logger2.GetLogger() - -func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool { - services := make([]*Service, 0) - taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) - notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) - return &ServicePool{ - taskQueue: taskQueue, - notifyQueue: notifyQueue, - services: services, - uploaderManager: manager, - db: db, - Clients: types.NewLMap[uint, *types.WsClient](), - licenseService: licenseService, - } -} - -func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) { - // stop old service - for _, s := range p.services { - s.Stop() - } - p.services = make([]*Service, 0) - - for _, config := range plusConfigs { - if config.Enabled == false { - continue - } - - cli := NewPlusClient(config, p.licenseService) - name := utils.Md5(config.ApiURL) - plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) - go func() { - plusService.Run() - }() - p.services = append(p.services, plusService) - } - - // for mid-journey proxy - for _, config := range proxyConfigs { - if config.Enabled == false { - continue - } - cli := NewProxyClient(config) - name := utils.Md5(config.ApiURL) - proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) - go func() { - proxyService.Run() - }() - p.services = append(p.services, proxyService) - } -} - -func (p *ServicePool) CheckTaskNotify() { - go func() { - for { - var message sd.NotifyMessage - err := p.notifyQueue.LPop(&message) - if err != nil { - continue - } - cli := p.Clients.Get(uint(message.UserId)) - if cli == nil { - continue - } - err = cli.Send([]byte(message.Message)) - if err != nil { - continue - } - } - }() -} - -func (p *ServicePool) DownloadImages() { - go func() { - var items []model.MidJourneyJob - for { - res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) - if res.Error != nil { - continue - } - - // download images - for _, v := range items { - if v.OrgURL == "" { - continue - } - - logger.Infof("try to download image: %s", v.OrgURL) - mjService := p.getService(v.ChannelId) - if mjService == nil { - logger.Errorf("Invalid task: %+v", v) - continue - } - - task, _ := mjService.Client.QueryTask(v.TaskId) - if len(task.Buttons) > 0 { - v.Hash = GetImageHash(task.Buttons[0].CustomId) - } - // 如果是返回的是 discord 图片地址,则使用代理下载 - proxy := false - if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") { - proxy = true - } - imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy) - - if err != nil { - logger.Errorf("error with download image %s, %v", v.OrgURL, err) - continue - } else { - logger.Infof("download image %s successfully.", v.OrgURL) - } - - v.ImgURL = imgURL - p.db.Updates(&v) - - cli := p.Clients.Get(uint(v.UserId)) - if cli == nil { - continue - } - err = cli.Send([]byte(sd.Finished)) - if err != nil { - continue - } - } - - time.Sleep(time.Second * 5) - } - }() -} - -// PushTask push a new mj task in to task queue -func (p *ServicePool) PushTask(task types.MjTask) { - logger.Debugf("add a new MidJourney task to the task list: %+v", task) - p.taskQueue.RPush(task) -} - -// HasAvailableService check if it has available mj service in pool -func (p *ServicePool) HasAvailableService() bool { - return len(p.services) > 0 -} - -// SyncTaskProgress 异步拉取任务 -func (p *ServicePool) SyncTaskProgress() { - go func() { - var jobs []model.MidJourneyJob - for { - res := p.db.Where("progress < ?", 100).Find(&jobs) - if res.Error != nil { - continue - } - - for _, job := range jobs { - // 5 分钟还没完成的任务标记为失败 - if time.Now().Sub(job.CreatedAt) > time.Minute*5 { - job.Progress = 101 - job.ErrMsg = "任务超时" - p.db.Updates(&job) - continue - } - - if servicePlus := p.getService(job.ChannelId); servicePlus != nil { - _ = servicePlus.Notify(job) - } - } - - time.Sleep(time.Second * 10) - } - }() -} - -func (p *ServicePool) getService(name string) *Service { - for _, s := range p.services { - if s.Name == name { - return s - } - } - return nil -} diff --git a/api/service/mj/proxy_client.go b/api/service/mj/proxy_client.go deleted file mode 100644 index e6a557d4..00000000 --- a/api/service/mj/proxy_client.go +++ /dev/null @@ -1,185 +0,0 @@ -package mj - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "encoding/base64" - "errors" - "fmt" - "geekai/core/types" - "geekai/utils" - "github.com/imroc/req/v3" - "io" -) - -// ProxyClient MidJourney Proxy Client -type ProxyClient struct { - Config types.MjProxyConfig - apiURL string -} - -func NewProxyClient(config types.MjProxyConfig) *ProxyClient { - return &ProxyClient{Config: config, apiURL: config.ApiURL} -} - -func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) { - apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL) - prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) - if task.NegPrompt != "" { - prompt += fmt.Sprintf(" --no %s", task.NegPrompt) - } - body := ImageReq{ - Prompt: prompt, - Base64Array: make([]string, 0), - } - // 生成图片 Base64 编码 - if len(task.ImgArr) > 0 { - imageData, err := utils.DownloadImage(task.ImgArr[0], "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) - } - - } - logger.Info("API URL: ", apiURL) - var res ImageRes - var errRes ErrRes - r, err := req.C().R(). - SetHeader("mj-api-secret", c.Config.ApiKey). - SetBody(body). - SetSuccessResult(&res). - SetErrorResult(&errRes). - Post(apiURL) - if err != nil { - return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) - } - - if r.IsErrorState() { - errStr, _ := io.ReadAll(r.Body) - return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) - } - - return res, nil -} - -// Blend 融图 -func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) { - apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL) - body := ImageReq{ - Dimensions: "SQUARE", - Base64Array: make([]string, 0), - } - // 生成图片 Base64 编码 - if len(task.ImgArr) > 0 { - for _, imgURL := range task.ImgArr { - imageData, err := utils.DownloadImage(imgURL, "") - if err != nil { - logger.Error("error with download image: ", err) - } else { - body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) - } - } - } - var res ImageRes - var errRes ErrRes - r, err := req.C().R(). - SetHeader("mj-api-secret", c.Config.ApiKey). - SetBody(body). - SetSuccessResult(&res). - SetErrorResult(&errRes). - Post(apiURL) - if err != nil { - return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) - } - - if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) - } - - return res, nil -} - -// SwapFace 换脸 -func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) { - return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus") -} - -// Upscale 放大指定的图片 -func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) { - body := map[string]interface{}{ - "action": "UPSCALE", - "index": task.Index, - "taskId": task.MessageId, - } - apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) - var res ImageRes - var errRes ErrRes - r, err := req.C().R(). - SetHeader("mj-api-secret", c.Config.ApiKey). - SetBody(body). - SetSuccessResult(&res). - SetErrorResult(&errRes). - Post(apiURL) - if err != nil { - return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) - } - - if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) - } - - return res, nil -} - -// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) { - body := map[string]interface{}{ - "action": "VARIATION", - "index": task.Index, - "taskId": task.MessageId, - } - apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) - var res ImageRes - var errRes ErrRes - r, err := req.C().R(). - SetHeader("mj-api-secret", c.Config.ApiKey). - SetBody(body). - SetSuccessResult(&res). - SetErrorResult(&errRes). - Post(apiURL) - if err != nil { - return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) - } - - if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) - } - - return res, nil -} - -func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) { - apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) - var res QueryRes - r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey). - SetSuccessResult(&res). - Get(apiURL) - - if err != nil { - return QueryRes{}, err - } - - if r.IsErrorState() { - return QueryRes{}, errors.New("error status:" + r.Status) - } - - return res, nil -} - -var _ Client = &ProxyClient{} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 60b1fc50..4c9059fa 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -11,10 +11,11 @@ import ( "fmt" "geekai/core/types" "geekai/service" - "geekai/service/sd" + "geekai/service/oss" "geekai/store" "geekai/store/model" "geekai/utils" + "github.com/go-redis/redis/v8" "strings" "time" @@ -23,127 +24,112 @@ import ( // Service MJ 绘画服务 type Service struct { - Name string // service Name - Client Client // MJ Client - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - running bool - retryCount map[uint]int + client *Client // MJ Client + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + Clients *types.LMap[uint, *types.WsClient] // UserId => Client + uploaderManager *oss.UploaderManager } -func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { +func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service { return &Service{ - Name: name, - db: db, - taskQueue: taskQueue, - notifyQueue: notifyQueue, - Client: cli, - running: true, - retryCount: make(map[uint]int), + db: db, + taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli), + client: client, + Clients: types.NewLMap[uint, *types.WsClient](), + uploaderManager: manager, } } -const failedProgress = 101 - func (s *Service) Run() { - logger.Infof("Starting MidJourney job consumer for %s", s.Name) - for s.running { - var task types.MjTask - err := s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } - - // 如果配置了多个中转平台的 API KEY - // U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 - if task.ChannelId != "" && task.ChannelId != s.Name { - if s.retryCount[task.Id] > 5 { - s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{}) + logger.Info("Starting MidJourney job consumer for service") + go func() { + for { + var task types.MjTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) continue } - logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) - s.taskQueue.RPush(task) - s.retryCount[task.Id]++ - time.Sleep(time.Second) - continue - } - // translate prompt - if utils.HasChinese(task.Prompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini") - if err == nil { - task.Prompt = content - } else { - logger.Warnf("error with translate prompt: %v", err) + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini") + if err == nil { + task.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } } - } - // translate negative prompt - if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { - content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini") - if err == nil { - task.NegPrompt = content - } else { - logger.Warnf("error with translate prompt: %v", err) - } - } - - var job model.MidJourneyJob - tx := s.db.Where("id = ?", task.Id).First(&job) - if tx.Error != nil { - logger.Error("任务不存在,任务ID:", task.TaskId) - continue - } - - logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) - var res ImageRes - switch task.Type { - case types.TaskImage: - res, err = s.Client.Imagine(task) - break - case types.TaskUpscale: - res, err = s.Client.Upscale(task) - break - case types.TaskVariation: - res, err = s.Client.Variation(task) - break - case types.TaskBlend: - res, err = s.Client.Blend(task) - break - case types.TaskSwapFace: - res, err = s.Client.SwapFace(task) - break - } - - if err != nil || (res.Code != 1 && res.Code != 22) { - var errMsg string - if err != nil { - errMsg = err.Error() - } else { - errMsg = fmt.Sprintf("%v,%s", err, res.Description) + // translate negative prompt + if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini") + if err == nil { + task.NegPrompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } } - logger.Error("绘画任务执行失败:", errMsg) - job.Progress = failedProgress - job.ErrMsg = errMsg - // update the task progress + // use fast mode as default + if task.Mode == "" { + task.Mode = "fast" + } + + var job model.MidJourneyJob + tx := s.db.Where("id = ?", task.Id).First(&job) + if tx.Error != nil { + logger.Error("任务不存在,任务ID:", task.TaskId) + continue + } + + logger.Infof("handle a new MidJourney task: %+v", task) + var res ImageRes + switch task.Type { + case types.TaskImage: + res, err = s.client.Imagine(task) + break + case types.TaskUpscale: + res, err = s.client.Upscale(task) + break + case types.TaskVariation: + res, err = s.client.Variation(task) + break + case types.TaskBlend: + res, err = s.client.Blend(task) + break + case types.TaskSwapFace: + res, err = s.client.SwapFace(task) + break + } + + if err != nil || (res.Code != 1 && res.Code != 22) { + var errMsg string + if err != nil { + errMsg = err.Error() + } else { + errMsg = fmt.Sprintf("%v,%s", err, res.Description) + } + + logger.Error("绘画任务执行失败:", errMsg) + job.Progress = service.FailTaskProgress + job.ErrMsg = errMsg + // update the task progress + s.db.Updates(&job) + // 任务失败,通知前端 + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + continue + } + logger.Infof("任务提交成功:%+v", res) + // 更新任务 ID/频道 + job.TaskId = res.Result + job.MessageId = res.Result + job.ChannelId = res.Channel s.db.Updates(&job) - // 任务失败,通知前端 - s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) - continue } - logger.Infof("任务提交成功:%+v", res) - // 更新任务 ID/频道 - job.TaskId = res.Result - job.MessageId = res.Result - job.ChannelId = s.Name - s.db.Updates(&job) - } -} - -func (s *Service) Stop() { - s.running = false + }() } type CBReq struct { @@ -164,46 +150,6 @@ type CBReq struct { } `json:"properties"` } -func (s *Service) Notify(job model.MidJourneyJob) error { - task, err := s.Client.QueryTask(job.TaskId) - if err != nil { - return err - } - - // 任务执行失败了 - if task.FailReason != "" { - s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ - "progress": failedProgress, - "err_msg": task.FailReason, - }) - s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) - return fmt.Errorf("task failed: %v", task.FailReason) - } - - if len(task.Buttons) > 0 { - job.Hash = GetImageHash(task.Buttons[0].CustomId) - } - oldProgress := job.Progress - job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) - job.Prompt = task.PromptEn - if task.ImageUrl != "" { - job.OrgURL = task.ImageUrl - } - tx := s.db.Updates(&job) - if tx.Error != nil { - return fmt.Errorf("error with update database: %v", tx.Error) - } - // 通知前端更新任务进度 - if oldProgress != job.Progress { - message := sd.Running - if job.Progress == 100 { - message = sd.Finished - } - s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) - } - return nil -} - func GetImageHash(action string) string { split := strings.Split(action, "::") if len(split) > 5 { @@ -211,3 +157,143 @@ func GetImageHash(action string) string { } return split[len(split)-1] } + +func (s *Service) CheckTaskNotify() { + go func() { + for { + var message service.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + cli := s.Clients.Get(uint(message.UserId)) + if cli == nil { + continue + } + err = cli.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadImages() { + go func() { + var items []model.MidJourneyJob + for { + res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) + if res.Error != nil { + continue + } + + // download images + for _, v := range items { + if v.OrgURL == "" { + continue + } + + logger.Infof("try to download image: %s", v.OrgURL) + // 如果是返回的是 discord 图片地址,则使用代理下载 + proxy := false + if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") { + proxy = true + } + imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy) + + if err != nil { + logger.Errorf("error with download image %s, %v", v.OrgURL, err) + continue + } else { + logger.Infof("download image %s successfully.", v.OrgURL) + } + + v.ImgURL = imgURL + s.db.Updates(&v) + + cli := s.Clients.Get(uint(v.UserId)) + if cli == nil { + continue + } + err = cli.Send([]byte(service.TaskStatusFinished)) + if err != nil { + continue + } + } + + time.Sleep(time.Second * 5) + } + }() +} + +// PushTask push a new mj task in to task queue +func (s *Service) PushTask(task types.MjTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +// SyncTaskProgress 异步拉取任务 +func (s *Service) SyncTaskProgress() { + go func() { + var jobs []model.MidJourneyJob + for { + res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs) + if res.Error != nil { + continue + } + + for _, job := range jobs { + // 10 分钟还没完成的任务标记为失败 + if time.Now().Sub(job.CreatedAt) > time.Minute*10 { + job.Progress = service.FailTaskProgress + job.ErrMsg = "任务超时" + s.db.Updates(&job) + continue + } + + task, err := s.client.QueryTask(job.TaskId, job.ChannelId) + if err != nil { + logger.Errorf("error with query task: %v", err) + continue + } + + // 任务执行失败了 + if task.FailReason != "" { + s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": service.FailTaskProgress, + "err_msg": task.FailReason, + }) + logger.Errorf("task failed: %v", task.FailReason) + s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) + continue + } + + if len(task.Buttons) > 0 { + job.Hash = GetImageHash(task.Buttons[0].CustomId) + } + oldProgress := job.Progress + job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) + job.Prompt = task.PromptEn + if task.ImageUrl != "" { + job.OrgURL = task.ImageUrl + } + err = s.db.Updates(&job).Error + if err != nil { + logger.Errorf("error with update database: %v", err) + continue + } + + // 通知前端更新任务进度 + if oldProgress != job.Progress { + message := service.TaskStatusRunning + if job.Progress == 100 { + message = service.TaskStatusFinished + } + s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) + } + } + + time.Sleep(time.Second * 5) + } + }() +} diff --git a/api/service/sd/pool.go b/api/service/sd/pool.go index 1d70a64a..d0033f67 100644 --- a/api/service/sd/pool.go +++ b/api/service/sd/pool.go @@ -10,6 +10,7 @@ package sd import ( "fmt" "geekai/core/types" + "geekai/service" "geekai/service/oss" "geekai/store" "geekai/store/model" @@ -79,7 +80,7 @@ func (p *ServicePool) CheckTaskNotify() { go func() { logger.Info("Running Stable-Diffusion task notify checking ...") for { - var message NotifyMessage + var message service.NotifyMessage err := p.notifyQueue.LPop(&message) if err != nil { continue diff --git a/api/service/sd/service.go b/api/service/sd/service.go index a9d707c7..d3b6c231 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -10,6 +10,7 @@ package sd import ( "fmt" "geekai/core/types" + logger2 "geekai/logger" "geekai/service" "geekai/service/oss" "geekai/store" @@ -22,6 +23,8 @@ import ( "gorm.io/gorm" ) +var logger = logger2.GetLogger() + // SD 绘画服务 type Service struct { @@ -87,11 +90,11 @@ func (s *Service) Run() { logger.Error("绘画任务执行失败:", err.Error()) // update the task progress s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ - "progress": 101, + "progress": service.FailTaskProgress, "err_msg": err.Error(), }) // 通知前端,任务失败 - s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed}) continue } } @@ -206,7 +209,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { // task finished s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) - s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished}) // 从 leveldb 中删除预览图片数据 _ = s.leveldb.Delete(task.Params.TaskId) return nil @@ -216,7 +219,7 @@ func (s *Service) Txt2Img(task types.SdTask) error { if err == nil && resp.Progress > 0 { s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) // 发送更新状态信号 - s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning}) // 保存预览图片数据 if resp.CurrentImage != "" { _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) diff --git a/api/service/sd/types.go b/api/service/sd/types.go deleted file mode 100644 index efdb970a..00000000 --- a/api/service/sd/types.go +++ /dev/null @@ -1,24 +0,0 @@ -package sd - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import logger2 "geekai/logger" - -var logger = logger2.GetLogger() - -type NotifyMessage struct { - UserId int `json:"user_id"` - JobId int `json:"job_id"` - Message string `json:"message"` -} - -const ( - Running = "RUNNING" - Finished = "FINISH" - Failed = "FAIL" -) diff --git a/api/service/suno/service.go b/api/service/suno/service.go index 8f4defe5..b9599303 100644 --- a/api/service/suno/service.go +++ b/api/service/suno/service.go @@ -13,8 +13,8 @@ import ( "fmt" "geekai/core/types" logger2 "geekai/logger" + "geekai/service" "geekai/service/oss" - "geekai/service/sd" "geekai/store" "geekai/store/model" "geekai/utils" @@ -88,7 +88,7 @@ func (s *Service) Run() { logger.Errorf("create task with error: %v", err) s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ "err_msg": err.Error(), - "progress": 101, + "progress": service.FailTaskProgress, }) continue } @@ -157,6 +157,9 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) { if res.Code != "success" { return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message) } + // update the last_use_at for api key + apiKey.LastUsedAt = time.Now().Unix() + session.Updates(&apiKey) res.Channel = apiKey.ApiURL return res, nil } @@ -165,7 +168,7 @@ func (s *Service) CheckTaskNotify() { go func() { logger.Info("Running Suno task notify checking ...") for { - var message sd.NotifyMessage + var message service.NotifyMessage err := s.notifyQueue.LPop(&message) if err != nil { continue @@ -210,7 +213,7 @@ func (s *Service) DownloadImages() { v.AudioURL = audioURL v.Progress = 100 s.db.Updates(&v) - s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished}) } time.Sleep(time.Second * 10) @@ -278,10 +281,10 @@ func (s *Service) SyncTaskProgress() { tx.Commit() } else if task.Data.FailReason != "" { - job.Progress = 101 + job.Progress = service.FailTaskProgress job.ErrMsg = task.Data.FailReason s.db.Updates(&job) - s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) + s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed}) } } diff --git a/api/service/types.go b/api/service/types.go index 15a538a2..70f8eb92 100644 --- a/api/service/types.go +++ b/api/service/types.go @@ -1,4 +1,17 @@ package service +const FailTaskProgress = 101 +const ( + TaskStatusRunning = "RUNNING" + TaskStatusFinished = "FINISH" + TaskStatusFailed = "FAIL" +) + +type NotifyMessage struct { + UserId int `json:"user_id"` + JobId int `json:"job_id"` + Message string `json:"message"` +} + const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/database/update-v4.1.2.sql b/database/update-v4.1.2.sql index 301ccae5..40acbd67 100644 --- a/database/update-v4.1.2.sql +++ b/database/update-v4.1.2.sql @@ -1 +1,2 @@ -ALTER TABLE `chatgpt_suno_jobs` MODIFY `id` INT AUTO_INCREMENT; \ No newline at end of file +ALTER TABLE `chatgpt_suno_jobs` MODIFY `id` INT AUTO_INCREMENT; +ALTER TABLE `chatgpt_mj_jobs` CHANGE `channel_id` `channel_id` VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '频道ID'; \ No newline at end of file diff --git a/deploy/conf/config.toml b/deploy/conf/config.toml index 2b834a43..6c16cb26 100644 --- a/deploy/conf/config.toml +++ b/deploy/conf/config.toml @@ -69,18 +69,6 @@ TikaHost = "http://tika:9998" SubDir = "" Domain = "" -[[MjProxyConfigs]] - Enabled = false - ApiURL = "http://geekai-midjourney-proxy:8082" - Mode = "" - ApiKey = "sk-geekmaster" - -[[MjPlusConfigs]] - Enabled = false - ApiURL = "https://api.chat-plus.net" - Mode = "fast" - ApiKey = "sk-xxx" - [[SdConfigs]] Enabled = false Model = "" diff --git a/web/src/components/ChatReply.vue b/web/src/components/ChatReply.vue index 74b15760..9826f2b7 100644 --- a/web/src/components/ChatReply.vue +++ b/web/src/components/ChatReply.vue @@ -6,7 +6,7 @@