diff --git a/CHANGELOG.md b/CHANGELOG.md index 419a93bb..68e82937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,13 @@ # 更新日志 +## v4.0.1 +* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion 发行版,稳定性更强一些 +* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容 MJ-Plus 中转 +* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力 +* 功能新增:支持前端菜单可以配置 +* 功能优化:手机端支持免登录预览功能 +* 功能新增:手机端支持 Stable-Diffusion 绘画 +* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换 + ## v4.0.0 非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。 只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ 对话消耗15个算力... diff --git a/api/core/types/config.go b/api/core/types/config.go index 612d7ddc..a0a6b62b 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -5,22 +5,22 @@ import ( ) type AppConfig struct { - Path string `toml:"-"` - Listen string - Session Session - AdminSession Session - ProxyURL string - MysqlDns string // mysql 连接地址 - StaticDir string // 静态资源目录 - StaticUrl string // 静态资源 URL - Redis RedisConfig // redis 连接信息 - ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs - SMS SMSConfig // send mobile message config - OSS OSSConfig // OSS config - MjConfigs []MidJourneyConfig // mj AI draw service pool - MjPlusConfigs []MidJourneyPlusConfig // MJ plus config - WeChatBot bool // 是否启用微信机器人 - SdConfigs []StableDiffusionConfig // sd AI draw service pool + Path string `toml:"-"` + Listen string + Session Session + AdminSession Session + ProxyURL string + MysqlDns string // mysql 连接地址 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs + SMS SMSConfig // send mobile message config + OSS OSSConfig // OSS config + MjProxyConfigs []MjProxyConfig // MJ proxy config + MjPlusConfigs []MjPlusConfig // MJ plus config + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool XXLConfig XXLConfig AlipayConfig AlipayConfig @@ -43,16 +43,11 @@ type ChatPlusApiConfig struct { Token string } -type MidJourneyConfig struct { - Enabled bool - UserToken string - BotToken string - GuildId string // Server ID - ChanelId string // Chanel ID - UseCDN bool - ImgCdnURL string // 图片反代加速地址 - DiscordAPI string - DiscordGateway string +type MjProxyConfig struct { + Enabled bool + ApiURL string // api 地址 + Mode string // 绘画模式,可选值:fast/turbo/relax + ApiKey string } type StableDiffusionConfig struct { @@ -62,12 +57,11 @@ type StableDiffusionConfig struct { ApiKey string } -type MidJourneyPlusConfig struct { - Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 - ApiURL string // api 地址 - Mode string // 绘画模式,可选值:fast/turbo/relax - ApiKey string - NotifyURL string // 任务进度更新回调地址 +type MjPlusConfig struct { + Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 + ApiURL string // api 地址 + Mode string // 绘画模式,可选值:fast/turbo/relax + ApiKey string } type AlipayConfig struct { diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 82323584..52526ad2 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -5,7 +5,6 @@ import ( "chatplus/core/types" "chatplus/service" "chatplus/service/mj" - "chatplus/service/mj/plus" "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" @@ -454,27 +453,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) { resp.SUCCESS(c) } -// Notify MidJourney Plus 服务任务回调处理 -func (h *MidJourneyHandler) Notify(c *gin.Context) { - var data plus.CBReq - if err := c.ShouldBindJSON(&data); err != nil { - logger.Error("非法任务回调:%+v", err) - return - } - err := h.pool.Notify(data) - if err != nil { - logger.Error(err) - } else { - userId := h.GetLoginUserId(c) - client := h.pool.Clients.Get(userId) - if client != nil { - _ = client.Send([]byte("Task Updated")) - } - } - - resp.SUCCESS(c) -} - // Publish 发布图片到画廊显示 func (h *MidJourneyHandler) Publish(c *gin.Context) { var data struct { diff --git a/api/main.go b/api/main.go index d070f341..ff43cdf7 100644 --- a/api/main.go +++ b/api/main.go @@ -252,7 +252,6 @@ func main() { group.GET("jobs", h.JobList) group.GET("imgWall", h.ImgWall) group.POST("remove", h.Remove) - group.POST("notify", h.Notify) group.POST("publish", h.Publish) }), fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go deleted file mode 100644 index 14ee8368..00000000 --- a/api/service/mj/bot.go +++ /dev/null @@ -1,233 +0,0 @@ -package mj - -import ( - "chatplus/core/types" - logger2 "chatplus/logger" - "chatplus/utils" - discordgo "github.com/bg5t/mydiscordgo" - "github.com/gorilla/websocket" - "net/http" - "net/url" - "regexp" - "strings" -) - -// MidJourney 机器人 - -var logger = logger2.GetLogger() - -type Bot struct { - config types.MidJourneyConfig - bot *discordgo.Session - name string - service *Service -} - -func NewBot(name string, proxy string, config types.MidJourneyConfig, service *Service) (*Bot, error) { - bot, err := discordgo.New("Bot " + config.BotToken) - if err != nil { - logger.Error(err) - return nil, err - } - - // use CDN reverse proxy - if config.UseCDN { - discordgo.SetEndpointDiscord(config.DiscordAPI) - discordgo.SetEndpointCDN("https://cdn.discordapp.com") - discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/") - bot.MjGateway = config.DiscordGateway + "/" - } else { // use proxy - discordgo.SetEndpointDiscord("https://discord.com") - discordgo.SetEndpointCDN("https://cdn.discordapp.com") - discordgo.SetEndpointStatus("https://discord.com/api/v2/") - bot.MjGateway = "wss://gateway.discord.gg" - - if proxy != "" { - proxy, _ := url.Parse(proxy) - bot.Client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxy), - }, - } - bot.Dialer = &websocket.Dialer{ - Proxy: http.ProxyURL(proxy), - } - } - - } - - return &Bot{ - config: config, - bot: bot, - name: name, - service: service, - }, nil -} - -func (b *Bot) Run() error { - b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent - b.bot.AddHandler(b.messageCreate) - b.bot.AddHandler(b.messageUpdate) - - logger.Infof("Starting MidJourney %s", b.name) - err := b.bot.Open() - if err != nil { - logger.Errorf("Error opening Discord connection for %s, error: %v", b.name, err) - return err - } - logger.Infof("Starting MidJourney %s successfully!", b.name) - return nil -} - -type TaskStatus string - -const ( - Start = TaskStatus("Started") - Running = TaskStatus("Running") - Stopped = TaskStatus("Stopped") - Finished = TaskStatus("Finished") -) - -type Image struct { - URL string `json:"url"` - ProxyURL string `json:"proxy_url"` - Filename string `json:"filename"` - Width int `json:"width"` - Height int `json:"height"` - Size int `json:"size"` - Hash string `json:"hash"` -} - -func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { - // ignore messages for other channels - if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId { - return - } - // ignore messages for self - if m.Author == nil || m.Author.ID == s.State.User.ID { - return - } - - logger.Debugf("CREATE: %s", utils.JsonEncode(m)) - var referenceId = "" - if m.ReferencedMessage != nil { - referenceId = m.ReferencedMessage.ID - } - if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") { - // parse content - req := CBReq{ - ChannelId: m.ChannelID, - MessageId: m.ID, - ReferenceId: referenceId, - Prompt: extractPrompt(m.Content), - Content: m.Content, - Progress: 0, - Status: Start} - b.service.Notify(req) - return - } - - b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) -} - -func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { - // ignore messages for other channels - if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId { - return - } - // ignore messages for self - if m.Author == nil || m.Author.ID == s.State.User.ID { - return - } - - logger.Debugf("UPDATE: %s", utils.JsonEncode(m)) - - var referenceId = "" - if m.ReferencedMessage != nil { - referenceId = m.ReferencedMessage.ID - } - if strings.Contains(m.Content, "(Stopped)") { - req := CBReq{ - ChannelId: m.ChannelID, - MessageId: m.ID, - ReferenceId: referenceId, - Prompt: extractPrompt(m.Content), - Content: m.Content, - Progress: extractProgress(m.Content), - Status: Stopped} - b.service.Notify(req) - return - } - - b.addAttachment(m.ChannelID, m.ID, referenceId, m.Content, m.Attachments) - -} - -func (b *Bot) addAttachment(channelId string, messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { - progress := extractProgress(content) - var status TaskStatus - if progress == 100 { - status = Finished - } else { - status = Running - } - for _, attachment := range attachments { - if attachment.Width == 0 || attachment.Height == 0 { - continue - } - image := Image{ - URL: attachment.URL, - Height: attachment.Height, - ProxyURL: attachment.ProxyURL, - Width: attachment.Width, - Size: attachment.Size, - Filename: attachment.Filename, - Hash: extractHashFromFilename(attachment.Filename), - } - req := CBReq{ - ChannelId: channelId, - MessageId: messageId, - ReferenceId: referenceId, - Image: image, - Prompt: extractPrompt(content), - Content: content, - Progress: progress, - Status: status, - } - b.service.Notify(req) - break // only get one image - } -} - -// extract prompt from string -func extractPrompt(input string) string { - pattern := `\*\*(.*?)\*\*` - re := regexp.MustCompile(pattern) - matches := re.FindStringSubmatch(input) - if len(matches) > 1 { - return strings.TrimSpace(matches[1]) - } - return "" -} - -func extractProgress(input string) int { - pattern := `\((\d+)\%\)` - re := regexp.MustCompile(pattern) - matches := re.FindStringSubmatch(input) - if len(matches) > 1 { - return utils.IntValue(matches[1], 0) - } - return 100 -} - -func extractHashFromFilename(filename string) string { - if !strings.HasSuffix(filename, ".png") { - return "" - } - - index := strings.LastIndex(filename, "_") - if index != -1 { - return filename[index+1 : len(filename)-4] - } - return "" -} diff --git a/api/service/mj/client.go b/api/service/mj/client.go index eada7586..2b71b007 100644 --- a/api/service/mj/client.go +++ b/api/service/mj/client.go @@ -1,159 +1,61 @@ package mj -import ( - "chatplus/core/types" - "errors" - "fmt" - "time" +import "chatplus/core/types" - "github.com/imroc/req/v3" -) - -// MidJourney client - -type Client struct { - client *req.Client - Config types.MidJourneyConfig - apiURL string +type Client interface { + Imagine(task types.MjTask) (ImageRes, error) + Blend(task types.MjTask) (ImageRes, error) + SwapFace(task types.MjTask) (ImageRes, error) + Upscale(task types.MjTask) (ImageRes, error) + Variation(task types.MjTask) (ImageRes, error) + QueryTask(taskId string) (QueryRes, error) } -func NewClient(config types.MidJourneyConfig, proxy string) *Client { - client := req.C().SetTimeout(10 * time.Second) - var apiURL string - // set proxy URL - if config.UseCDN { - apiURL = config.DiscordAPI + "/api/v9/interactions" - } else { - apiURL = "https://discord.com/api/v9/interactions" - if proxy != "" { - client.SetProxyURL(proxy) - } - } - - return &Client{client: client, Config: config, apiURL: apiURL} +type ImageReq struct { + BotType string `json:"botType,omitempty"` + Prompt string `json:"prompt,omitempty"` + Dimensions string `json:"dimensions,omitempty"` + Base64Array []string `json:"base64Array,omitempty"` + AccountFilter interface{} `json:"accountFilter,omitempty"` + NotifyHook string `json:"notifyHook,omitempty"` + State string `json:"state,omitempty"` } -func (c *Client) Imagine(task types.MjTask) error { - interactionsReq := &InteractionsRequest{ - Type: 2, - ApplicationID: ApplicationID, - GuildID: c.Config.GuildId, - ChannelID: c.Config.ChanelId, - SessionID: SessionID, - Data: map[string]any{ - "version": "1166847114203123795", - "id": "938956540159881230", - "name": "imagine", - "type": "1", - "options": []map[string]any{ - { - "type": 3, - "name": "prompt", - "value": fmt.Sprintf("%s %s", task.TaskId, task.Prompt), - }, - }, - "application_command": map[string]any{ - "id": "938956540159881230", - "application_id": ApplicationID, - "version": "1118961510123847772", - "default_permission": true, - "default_member_permissions": nil, - "type": 1, - "nsfw": false, - "name": "imagine", - "description": "Create images with Midjourney", - "dm_permission": true, - "options": []map[string]any{ - { - "type": 3, - "name": "prompt", - "description": "The prompt to imagine", - "required": true, - }, - }, - "attachments": []any{}, - }, - }, - } - - r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). - SetHeader("Content-Type", "application/json"). - SetBody(interactionsReq). - Post(c.apiURL) - - if err != nil || r.IsErrorState() { - return fmt.Errorf("error with http request: %w%v", err, r.Err) - } - - return nil +type ImageRes struct { + Code int `json:"code"` + Description string `json:"description"` + Properties struct { + } `json:"properties"` + Result string `json:"result"` } -func (c *Client) Blend(task types.MjTask) error { - return errors.New("function not implemented") +type ErrRes struct { + Error struct { + Message string `json:"message"` + } `json:"error"` } -func (c *Client) SwapFace(task types.MjTask) error { - return errors.New("function not implemented") -} - -// Upscale 放大指定的图片 -func (c *Client) Upscale(task types.MjTask) error { - flags := 0 - interactionsReq := &InteractionsRequest{ - Type: 3, - ApplicationID: ApplicationID, - GuildID: c.Config.GuildId, - ChannelID: c.Config.ChanelId, - MessageFlags: flags, - MessageID: task.MessageId, - SessionID: SessionID, - Data: map[string]any{ - "component_type": 2, - "custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), - }, - Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), - } - - var res InteractionsResult - r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). - SetHeader("Content-Type", "application/json"). - SetBody(interactionsReq). - SetErrorResult(&res). - Post(c.apiURL) - if err != nil || r.IsErrorState() { - return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) - } - - return nil -} - -// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *Client) Variation(task types.MjTask) error { - flags := 0 - interactionsReq := &InteractionsRequest{ - Type: 3, - ApplicationID: ApplicationID, - GuildID: c.Config.GuildId, - ChannelID: c.Config.ChanelId, - MessageFlags: flags, - MessageID: task.MessageId, - SessionID: SessionID, - Data: map[string]any{ - "component_type": 2, - "custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), - }, - Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), - } - - var res InteractionsResult - r, err := c.client.R().SetHeader("Authorization", c.Config.UserToken). - SetHeader("Content-Type", "application/json"). - SetBody(interactionsReq). - SetErrorResult(&res). - Post(c.apiURL) - if err != nil || r.IsErrorState() { - return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) - } - - return nil +type QueryRes struct { + Action string `json:"action"` + Buttons []struct { + CustomId string `json:"customId"` + Emoji string `json:"emoji"` + Label string `json:"label"` + Style int `json:"style"` + Type int `json:"type"` + } `json:"buttons"` + Description string `json:"description"` + FailReason string `json:"failReason"` + FinishTime int `json:"finishTime"` + Id string `json:"id"` + ImageUrl string `json:"imageUrl"` + Progress string `json:"progress"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Properties struct { + } `json:"properties"` + StartTime int `json:"startTime"` + State string `json:"state"` + Status string `json:"status"` + SubmitTime int `json:"submitTime"` } diff --git a/api/service/mj/plus/service.go b/api/service/mj/plus/service.go deleted file mode 100644 index f02db6cb..00000000 --- a/api/service/mj/plus/service.go +++ /dev/null @@ -1,194 +0,0 @@ -package plus - -import ( - "chatplus/core/types" - "chatplus/store" - "chatplus/store/model" - "chatplus/utils" - "fmt" - "strings" - "sync/atomic" - "time" - - "gorm.io/gorm" -) - -// Service MJ 绘画服务 -type Service struct { - Name string // service Name - Client *Client // MJ Client - taskQueue *store.RedisQueue - notifyQueue *store.RedisQueue - db *gorm.DB - maxHandleTaskNum int32 // max task number current service can handle - HandledTaskNum int32 // already handled task number - taskStartTimes map[int]time.Time // task start time, to check if the task is timeout - taskTimeout int64 -} - -func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { - return &Service{ - Name: name, - db: db, - taskQueue: taskQueue, - notifyQueue: notifyQueue, - Client: client, - taskTimeout: timeout, - maxHandleTaskNum: maxTaskNum, - taskStartTimes: make(map[int]time.Time, 0), - } -} - -func (s *Service) Run() { - logger.Infof("Starting MidJourney job consumer for %s", s.Name) - for { - s.checkTasks() - if !s.canHandleTask() { - // current service is full, can not handle more task - // waiting for running task finish - time.Sleep(time.Second * 3) - continue - } - - var task types.MjTask - err := s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } - - // if it's reference message, check if it's this channel's message - //if task.ChannelId != "" && task.ChannelId != s.Name { - // logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) - // s.taskQueue.RPush(task) - // time.Sleep(time.Second) - // continue - //} - - logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) - var res ImageRes - switch task.Type { - case types.TaskImage: - res, err = s.Client.Imagine(task) - break - case types.TaskUpscale: - res, err = s.Client.Upscale(task) - break - case types.TaskVariation: - res, err = s.Client.Variation(task) - break - case types.TaskBlend: - res, err = s.Client.Blend(task) - break - case types.TaskSwapFace: - res, err = s.Client.SwapFace(task) - break - } - - var job model.MidJourneyJob - s.db.Where("id = ?", task.Id).First(&job) - if err != nil || (res.Code != 1 && res.Code != 22) { - errMsg := fmt.Sprintf("%v,%s", err, res.Description) - logger.Error("绘画任务执行失败:", errMsg) - job.Progress = -1 - job.ErrMsg = errMsg - // update the task progress - s.db.Updates(&job) - // 任务失败,通知前端 - s.notifyQueue.RPush(task.UserId) - continue - } - logger.Infof("任务提交成功:%+v", res) - // lock the task until the execute timeout - s.taskStartTimes[int(task.Id)] = time.Now() - atomic.AddInt32(&s.HandledTaskNum, 1) - // 更新任务 ID/频道 - job.TaskId = res.Result - job.ChannelId = s.Name - s.db.Updates(&job) - } -} - -// check if current service instance can handle more task -func (s *Service) canHandleTask() bool { - handledNum := atomic.LoadInt32(&s.HandledTaskNum) - return handledNum < s.maxHandleTaskNum -} - -// remove the expired tasks -func (s *Service) checkTasks() { - for k, t := range s.taskStartTimes { - if time.Now().Unix()-t.Unix() > s.taskTimeout { - delete(s.taskStartTimes, k) - atomic.AddInt32(&s.HandledTaskNum, -1) - // delete task from database - s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") - } - } -} - -type CBReq struct { - Id string `json:"id"` - Action string `json:"action"` - Status string `json:"status"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Description string `json:"description"` - SubmitTime int64 `json:"submitTime"` - StartTime int64 `json:"startTime"` - FinishTime int64 `json:"finishTime"` - Progress string `json:"progress"` - ImageUrl string `json:"imageUrl"` - FailReason interface{} `json:"failReason"` - Properties struct { - FinalPrompt string `json:"finalPrompt"` - } `json:"properties"` -} - -func (s *Service) Notify(job model.MidJourneyJob) error { - task, err := s.Client.QueryTask(job.TaskId) - if err != nil { - return err - } - - // 任务执行失败了 - if task.FailReason != "" { - s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ - "progress": -1, - "err_msg": task.FailReason, - }) - return fmt.Errorf("task failed: %v", task.FailReason) - } - - if len(task.Buttons) > 0 { - job.Hash = GetImageHash(task.Buttons[0].CustomId) - } - oldProgress := job.Progress - job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) - job.Prompt = task.PromptEn - if task.ImageUrl != "" { - job.OrgURL = task.ImageUrl - } - job.MessageId = task.Id - tx := s.db.Updates(&job) - if tx.Error != nil { - return fmt.Errorf("error with update database: %v", tx.Error) - } - if task.Status == "SUCCESS" { - // release lock task - atomic.AddInt32(&s.HandledTaskNum, -1) - } - // 通知前端更新任务进度 - if oldProgress != job.Progress { - s.notifyQueue.RPush(job.UserId) - } - return nil -} - -func GetImageHash(action string) string { - split := strings.Split(action, "::") - if len(split) > 5 { - return split[4] - } - return split[len(split)-1] -} diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus_client.go similarity index 64% rename from api/service/mj/plus/client.go rename to api/service/mj/plus_client.go index 757ebc96..38c3265e 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus_client.go @@ -1,8 +1,7 @@ -package plus +package mj import ( "chatplus/core/types" - logger2 "chatplus/logger" "chatplus/utils" "encoding/base64" "errors" @@ -13,53 +12,21 @@ import ( "github.com/gin-gonic/gin" ) -var logger = logger2.GetLogger() - -// Client MidJourney Plus Client -type Client struct { - Config types.MidJourneyPlusConfig +// PlusClient MidJourney Plus ProxyClient +type PlusClient struct { + Config types.MjPlusConfig apiURL string } -func NewClient(config types.MidJourneyPlusConfig) *Client { - return &Client{Config: config, apiURL: config.ApiURL} +func NewPlusClient(config types.MjPlusConfig) *PlusClient { + return &PlusClient{Config: config, apiURL: config.ApiURL} } -type ImageReq struct { - BotType string `json:"botType"` - Prompt string `json:"prompt,omitempty"` - Dimensions string `json:"dimensions,omitempty"` - Base64Array []string `json:"base64Array,omitempty"` - AccountFilter struct { - InstanceId string `json:"instanceId"` - Modes []interface{} `json:"modes"` - Remix bool `json:"remix"` - RemixAutoConsidered bool `json:"remixAutoConsidered"` - } `json:"accountFilter,omitempty"` - NotifyHook string `json:"notifyHook"` - State string `json:"state,omitempty"` -} - -type ImageRes struct { - Code int `json:"code"` - Description string `json:"description"` - Properties struct { - } `json:"properties"` - Result string `json:"result"` -} - -type ErrRes struct { - Error struct { - Message string `json:"message"` - } `json:"error"` -} - -func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { +func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) body := ImageReq{ BotType: "MID_JOURNEY", Prompt: task.Prompt, - NotifyHook: c.Config.NotifyURL, Base64Array: make([]string, 0), } // 生成图片 Base64 编码 @@ -94,12 +61,11 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { } // Blend 融图 -func (c *Client) Blend(task types.MjTask) (ImageRes, error) { +func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) body := ImageReq{ BotType: "MID_JOURNEY", Dimensions: "SQUARE", - NotifyHook: c.Config.NotifyURL, Base64Array: make([]string, 0), } // 生成图片 Base64 编码 @@ -133,7 +99,7 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) { } // SwapFace 换脸 -func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { +func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) // 生成图片 Base64 编码 if len(task.ImgArr) != 2 { @@ -160,8 +126,7 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { "accountFilter": gin.H{ "instanceId": "", }, - "notifyHook": c.Config.NotifyURL, - "state": "", + "state": "", } var res ImageRes var errRes ErrRes @@ -183,11 +148,10 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { } // Upscale 放大指定的图片 -func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { +func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), - "taskId": task.MessageId, - "notifyHook": c.Config.NotifyURL, + "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, } apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) var res ImageRes @@ -210,11 +174,10 @@ func (c *Client) Upscale(task types.MjTask) (ImageRes, error) { } // Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 -func (c *Client) Variation(task types.MjTask) (ImageRes, error) { +func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { body := map[string]string{ - "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), - "taskId": task.MessageId, - "notifyHook": c.Config.NotifyURL, + "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, } apiURL := fmt.Sprintf("%s/mj/submit/action", c.apiURL) var res ImageRes @@ -236,32 +199,7 @@ func (c *Client) Variation(task types.MjTask) (ImageRes, error) { return res, nil } -type QueryRes struct { - Action string `json:"action"` - Buttons []struct { - CustomId string `json:"customId"` - Emoji string `json:"emoji"` - Label string `json:"label"` - Style int `json:"style"` - Type int `json:"type"` - } `json:"buttons"` - Description string `json:"description"` - FailReason string `json:"failReason"` - FinishTime int `json:"finishTime"` - Id string `json:"id"` - ImageUrl string `json:"imageUrl"` - Progress string `json:"progress"` - Prompt string `json:"prompt"` - PromptEn string `json:"promptEn"` - Properties struct { - } `json:"properties"` - StartTime int `json:"startTime"` - State string `json:"state"` - Status string `json:"status"` - SubmitTime int `json:"submitTime"` -} - -func (c *Client) QueryTask(taskId string) (QueryRes, error) { +func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) { apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) var res QueryRes r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). @@ -278,3 +216,5 @@ func (c *Client) QueryTask(taskId string) (QueryRes, error) { return res, nil } + +var _ Client = &PlusClient{} diff --git a/api/service/mj/pool.go b/api/service/mj/pool.go index fdf3428a..2cf23ac9 100644 --- a/api/service/mj/pool.go +++ b/api/service/mj/pool.go @@ -2,13 +2,12 @@ package mj import ( "chatplus/core/types" - "chatplus/service/mj/plus" + logger2 "chatplus/logger" "chatplus/service/oss" "chatplus/store" "chatplus/store/model" "fmt" "github.com/go-redis/redis/v8" - "strings" "time" "gorm.io/gorm" @@ -16,7 +15,7 @@ import ( // ServicePool Mj service pool type ServicePool struct { - services []interface{} + services []*Service taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB @@ -24,8 +23,10 @@ type ServicePool struct { Clients *types.LMap[uint, *types.WsClient] // UserId => Client } +var logger = logger2.GetLogger() + func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool { - services := make([]interface{}, 0) + services := make([]*Service, 0) taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) @@ -33,45 +34,26 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa if config.Enabled == false { continue } - client := plus.NewClient(config) - name := fmt.Sprintf("mj-service-plus-%d", k) - servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client) + cli := NewPlusClient(config) + name := fmt.Sprintf("mj-plus-service-%d", k) + service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) go func() { - servicePlus.Run() + service.Run() }() - services = append(services, servicePlus) + services = append(services, service) } - if len(services) == 0 { - // create mj client and service - for k, config := range appConfig.MjConfigs { - if config.Enabled == false { - continue - } - // create mj client - client := NewClient(config, appConfig.ProxyURL) - - name := fmt.Sprintf("MjService-%d", k) - // create mj service - service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client) - botName := fmt.Sprintf("MjBot-%d", k) - bot, err := NewBot(botName, appConfig.ProxyURL, config, service) - if err != nil { - continue - } - - err = bot.Run() - if err != nil { - continue - } - - // run mj service - go func() { - service.Run() - }() - - services = append(services, service) + for k, config := range appConfig.MjProxyConfigs { + if config.Enabled == false { + continue } + cli := NewProxyClient(config) + name := fmt.Sprintf("mj-proxy-service-%d", k) + service := NewService(name, taskQueue, notifyQueue, 4, 600, db, cli) + go func() { + service.Run() + }() + services = append(services, service) } return &ServicePool{ @@ -92,11 +74,11 @@ func (p *ServicePool) CheckTaskNotify() { if err != nil { continue } - client := p.Clients.Get(userId) - if client == nil { + cli := p.Clients.Get(userId) + if cli == nil { continue } - err = client.Send([]byte("Task Updated")) + err = cli.Send([]byte("Task Updated")) if err != nil { continue } @@ -122,10 +104,10 @@ func (p *ServicePool) DownloadImages() { logger.Infof("try to download image: %s", v.OrgURL) var imgURL string var err error - if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil { + if servicePlus := p.getService(v.ChannelId); servicePlus != nil { task, _ := servicePlus.Client.QueryTask(v.TaskId) if len(task.Buttons) > 0 { - v.Hash = plus.GetImageHash(task.Buttons[0].CustomId) + v.Hash = GetImageHash(task.Buttons[0].CustomId) } imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false) } else { @@ -141,11 +123,11 @@ func (p *ServicePool) DownloadImages() { v.ImgURL = imgURL p.db.Updates(&v) - client := p.Clients.Get(uint(v.UserId)) - if client == nil { + cli := p.Clients.Get(uint(v.UserId)) + if cli == nil { continue } - err = client.Send([]byte("Task Updated")) + err = cli.Send([]byte("Task Updated")) if err != nil { continue } @@ -167,25 +149,6 @@ func (p *ServicePool) HasAvailableService() bool { return len(p.services) > 0 } -func (p *ServicePool) Notify(data plus.CBReq) error { - logger.Debugf("收到任务回调:%+v", data) - var job model.MidJourneyJob - res := p.db.Where("task_id = ?", data.Id).First(&job) - if res.Error != nil { - return fmt.Errorf("非法任务:%s", data.Id) - } - - // 任务已经拉取完成 - if job.Progress == 100 { - return nil - } - if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { - return servicePlus.Notify(job) - } - - return nil -} - // SyncTaskProgress 异步拉取任务 func (p *ServicePool) SyncTaskProgress() { go func() { @@ -222,11 +185,7 @@ func (p *ServicePool) SyncTaskProgress() { } } - if !strings.HasPrefix(job.ChannelId, "mj-service-plus") { - continue - } - - if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil { + if servicePlus := p.getService(job.ChannelId); servicePlus != nil { _ = servicePlus.Notify(job) } } @@ -236,12 +195,10 @@ func (p *ServicePool) SyncTaskProgress() { }() } -func (p *ServicePool) getServicePlus(name string) *plus.Service { +func (p *ServicePool) getService(name string) *Service { for _, s := range p.services { - if servicePlus, ok := s.(*plus.Service); ok { - if servicePlus.Name == name { - return servicePlus - } + if s.Name == name { + return s } } return nil diff --git a/api/service/mj/proxy_client.go b/api/service/mj/proxy_client.go new file mode 100644 index 00000000..c6f66c64 --- /dev/null +++ b/api/service/mj/proxy_client.go @@ -0,0 +1,176 @@ +package mj + +import ( + "chatplus/core/types" + "chatplus/utils" + "encoding/base64" + "errors" + "fmt" + "github.com/imroc/req/v3" + "io" +) + +// ProxyClient MidJourney Proxy Client +type ProxyClient struct { + Config types.MjProxyConfig + apiURL string +} + +func NewProxyClient(config types.MjProxyConfig) *ProxyClient { + return &ProxyClient{Config: config, apiURL: config.ApiURL} +} + +func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL) + body := ImageReq{ + Prompt: task.Prompt, + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + + } + logger.Info("API URL: ", apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + all, err := io.ReadAll(r.Body) + logger.Info(string(all)) + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) + } + + return res, nil +} + +// Blend 融图 +func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL) + body := ImageReq{ + Dimensions: "SQUARE", + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + for _, imgURL := range task.ImgArr { + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + } + } + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// SwapFace 换脸 +func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) { + return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus") +} + +// Upscale 放大指定的图片 +func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) { + body := map[string]interface{}{ + "action": "UPSCALE", + "index": task.Index, + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) { + body := map[string]interface{}{ + "action": "VARIATION", + "index": task.Index, + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) { + apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) + var res QueryRes + r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + return QueryRes{}, err + } + + if r.IsErrorState() { + return QueryRes{}, errors.New("error status:" + r.Status) + } + + return res, nil +} + +var _ Client = &ProxyClient{} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index 154f70db..c739078e 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -16,24 +16,24 @@ import ( // Service MJ 绘画服务 type Service struct { - name string // service name - client *Client // MJ client + Name string // service Name + Client Client // MJ Client taskQueue *store.RedisQueue notifyQueue *store.RedisQueue db *gorm.DB maxHandleTaskNum int32 // max task number current service can handle - handledTaskNum int32 // already handled task number + HandledTaskNum int32 // already handled task number taskStartTimes map[int]time.Time // task start time, to check if the task is timeout taskTimeout int64 } -func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service { +func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, cli Client) *Service { return &Service{ - name: name, + Name: name, db: db, taskQueue: taskQueue, notifyQueue: notifyQueue, - client: client, + Client: cli, taskTimeout: timeout, maxHandleTaskNum: maxTaskNum, taskStartTimes: make(map[int]time.Time, 0), @@ -41,7 +41,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red } func (s *Service) Run() { - logger.Infof("Starting MidJourney job consumer for %s", s.name) + logger.Infof("Starting MidJourney job consumer for %s", s.Name) for { s.checkTasks() if !s.canHandleTask() { @@ -58,65 +58,72 @@ func (s *Service) Run() { continue } - // if it's reference message, check if it's this channel's message - if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId { + // 如果配置了多个中转平台的 API KEY + // U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 + if task.ChannelId != "" && task.ChannelId != s.Name { + logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) s.taskQueue.RPush(task) time.Sleep(time.Second) continue } // 翻译提示词 - if utils.HasChinese(task.Prompt) { + if utils.HasChinese(task.Prompt) && strings.HasPrefix(s.Name, "mj-proxy-service") { content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt)) if err == nil { task.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) } } - logger.Infof("%s handle a new MidJourney task: %+v", s.name, task) + logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) + var res ImageRes switch task.Type { case types.TaskImage: - err = s.client.Imagine(task) + res, err = s.Client.Imagine(task) break case types.TaskUpscale: - err = s.client.Upscale(task) + res, err = s.Client.Upscale(task) break case types.TaskVariation: - err = s.client.Variation(task) + res, err = s.Client.Variation(task) break case types.TaskBlend: - err = s.client.Blend(task) + res, err = s.Client.Blend(task) break case types.TaskSwapFace: - err = s.client.SwapFace(task) + res, err = s.Client.SwapFace(task) break } - if err != nil { - logger.Error("绘画任务执行失败:", err.Error()) + var job model.MidJourneyJob + s.db.Where("id = ?", task.Id).First(&job) + if err != nil || (res.Code != 1 && res.Code != 22) { + errMsg := fmt.Sprintf("%v,%s", err, res.Description) + logger.Error("绘画任务执行失败:", errMsg) + job.Progress = -1 + job.ErrMsg = errMsg // update the task progress - s.db.Model(&model.MidJourneyJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ - "progress": -1, - "err_msg": err.Error(), - }) + s.db.Updates(&job) + // 任务失败,通知前端 s.notifyQueue.RPush(task.UserId) - // restore img_call quota - if task.Type.String() != types.TaskUpscale.String() { - s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1)) - } continue } - logger.Infof("Task Executed: %+v", task) + logger.Infof("任务提交成功:%+v", res) // lock the task until the execute timeout s.taskStartTimes[int(task.Id)] = time.Now() - atomic.AddInt32(&s.handledTaskNum, 1) - + atomic.AddInt32(&s.HandledTaskNum, 1) + // 更新任务 ID/频道 + job.TaskId = res.Result + job.ChannelId = s.Name + s.db.Updates(&job) } } // check if current service instance can handle more task func (s *Service) canHandleTask() bool { - handledNum := atomic.LoadInt32(&s.handledTaskNum) + handledNum := atomic.LoadInt32(&s.HandledTaskNum) return handledNum < s.maxHandleTaskNum } @@ -125,65 +132,75 @@ func (s *Service) checkTasks() { for k, t := range s.taskStartTimes { if time.Now().Unix()-t.Unix() > s.taskTimeout { delete(s.taskStartTimes, k) - atomic.AddInt32(&s.handledTaskNum, -1) + atomic.AddInt32(&s.HandledTaskNum, -1) // delete task from database s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100") } } } -func (s *Service) Notify(data CBReq) { - // extract the task ID - split := strings.Split(data.Prompt, " ") - var job model.MidJourneyJob - res := s.db.Where("message_id = ?", data.MessageId).First(&job) - if res.Error == nil && data.Status == Finished { - logger.Warn("重复消息:", data.MessageId) - return - } - - tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC") - if data.ReferenceId != "" { - tx = tx.Where("reference_id = ?", data.ReferenceId) - } else { - tx = tx.Where("task_id = ?", split[0]) - } - // fixed: 修复 U/V 操作任务混淆覆盖的 Bug - if strings.Contains(data.Prompt, "** - Image #") { // for upscale - tx = tx.Where("type = ?", types.TaskUpscale.String()) - } else if strings.Contains(data.Prompt, "** - Variations (Strong)") { // for Variations - tx = tx.Where("type = ?", types.TaskVariation.String()) - } - res = tx.First(&job) - if res.Error != nil { - logger.Warn("非法任务:", res.Error) - return - } - - job.ChannelId = data.ChannelId - job.MessageId = data.MessageId - job.ReferenceId = data.ReferenceId - job.Progress = data.Progress - job.Prompt = data.Prompt - job.Hash = data.Image.Hash - if s.client.Config.UseCDN { - job.UseProxy = true - job.OrgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.ImgCdnURL) - } else { - job.OrgURL = data.Image.URL - } - - res = s.db.Updates(&job) - if res.Error != nil { - logger.Error("error with update job: ", res.Error) - return - } - - if data.Status == Finished { - // release lock task - atomic.AddInt32(&s.handledTaskNum, -1) - } - - s.notifyQueue.RPush(job.UserId) - +type CBReq struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Progress string `json:"progress"` + ImageUrl string `json:"imageUrl"` + FailReason interface{} `json:"failReason"` + Properties struct { + FinalPrompt string `json:"finalPrompt"` + } `json:"properties"` +} + +func (s *Service) Notify(job model.MidJourneyJob) error { + task, err := s.Client.QueryTask(job.TaskId) + if err != nil { + return err + } + + // 任务执行失败了 + if task.FailReason != "" { + s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": task.FailReason, + }) + return fmt.Errorf("task failed: %v", task.FailReason) + } + + if len(task.Buttons) > 0 { + job.Hash = GetImageHash(task.Buttons[0].CustomId) + } + oldProgress := job.Progress + job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) + job.Prompt = task.PromptEn + if task.ImageUrl != "" { + job.OrgURL = task.ImageUrl + } + job.MessageId = task.Id + tx := s.db.Updates(&job) + if tx.Error != nil { + return fmt.Errorf("error with update database: %v", tx.Error) + } + if task.Status == "SUCCESS" { + // release lock task + atomic.AddInt32(&s.HandledTaskNum, -1) + } + // 通知前端更新任务进度 + if oldProgress != job.Progress { + s.notifyQueue.RPush(job.UserId) + } + return nil +} + +func GetImageHash(action string) string { + split := strings.Split(action, "::") + if len(split) > 5 { + return split[4] + } + return split[len(split)-1] } diff --git a/api/service/mj/types.go b/api/service/mj/types.go deleted file mode 100644 index ff6a5dd3..00000000 --- a/api/service/mj/types.go +++ /dev/null @@ -1,35 +0,0 @@ -package mj - -const ( - ApplicationID string = "936929561302675456" - SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe" -) - -type InteractionsRequest struct { - Type int `json:"type"` - ApplicationID string `json:"application_id"` - MessageFlags int `json:"message_flags,omitempty"` - MessageID string `json:"message_id,omitempty"` - GuildID string `json:"guild_id"` - ChannelID string `json:"channel_id"` - SessionID string `json:"session_id"` - Data map[string]any `json:"data"` - Nonce string `json:"nonce,omitempty"` -} - -type InteractionsResult struct { - Code int `json:"code"` - Message string - Error map[string]any -} - -type CBReq struct { - ChannelId string `json:"channel_id"` - MessageId string `json:"message_id"` - ReferenceId string `json:"reference_id"` - Image Image `json:"image"` - Content string `json:"content"` - Prompt string `json:"prompt"` - Status TaskStatus `json:"status"` - Progress int `json:"progress"` -} diff --git a/api/test/test.go b/api/test/test.go index 79058077..14362448 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,5 +1,11 @@ package main -func main() { +import ( + "chatplus/utils" + "fmt" +) +func main() { + text := "一只 蜗牛在树干上爬,阳光透过树叶照在蜗牛的背上 --ar 1:1 --iw 0.250000 --v 6" + fmt.Println(utils.HasChinese(text)) }