mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: refactor MidJourney service for conpatible drawing in chat and draw in app
This commit is contained in:
		@@ -49,15 +49,6 @@ type ChatModel struct {
 | 
				
			|||||||
	Value    string   `json:"value"`
 | 
						Value    string   `json:"value"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MjTask struct {
 | 
					 | 
				
			||||||
	ChatId      string
 | 
					 | 
				
			||||||
	MessageId   string
 | 
					 | 
				
			||||||
	MessageHash string
 | 
					 | 
				
			||||||
	UserId      uint
 | 
					 | 
				
			||||||
	RoleId      uint
 | 
					 | 
				
			||||||
	Icon        string
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ApiError struct {
 | 
					type ApiError struct {
 | 
				
			||||||
	Error struct {
 | 
						Error struct {
 | 
				
			||||||
		Message string
 | 
							Message string
 | 
				
			||||||
@@ -77,5 +68,3 @@ var ModelToTokens = map[string]int{
 | 
				
			|||||||
	"gpt-4":             8192,
 | 
						"gpt-4":             8192,
 | 
				
			||||||
	"gpt-4-32k":         32768,
 | 
						"gpt-4-32k":         32768,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
const TaskStorePrefix = "/tasks/"
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -131,6 +131,13 @@ func (h *ChatHandler) sendAzureMessage(
 | 
				
			|||||||
				utils.ReplyMessage(ws, "")
 | 
									utils.ReplyMessage(ws, "")
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				f := h.App.Functions[functionName]
 | 
									f := h.App.Functions[functionName]
 | 
				
			||||||
 | 
									if functionName == types.FuncMidJourney {
 | 
				
			||||||
 | 
										params["user_id"] = userVo.Id
 | 
				
			||||||
 | 
										params["role_id"] = role.Id
 | 
				
			||||||
 | 
										params["chat_id"] = session.ChatId
 | 
				
			||||||
 | 
										params["icon"] = "/images/avatar/mid_journey.png"
 | 
				
			||||||
 | 
										params["session_id"] = session.SessionId
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				data, err := f.Invoke(params)
 | 
									data, err := f.Invoke(params)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					msg := "调用函数出错:" + err.Error()
 | 
										msg := "调用函数出错:" + err.Error()
 | 
				
			||||||
@@ -142,22 +149,8 @@ func (h *ChatHandler) sendAzureMessage(
 | 
				
			|||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					content := data
 | 
										content := data
 | 
				
			||||||
					if functionName == types.FuncMidJourney {
 | 
										if functionName == types.FuncMidJourney {
 | 
				
			||||||
						key := utils.Sha256(data)
 | 
					 | 
				
			||||||
						logger.Debug(data, ",", key)
 | 
					 | 
				
			||||||
						// add task for MidJourney
 | 
					 | 
				
			||||||
						h.App.MjTaskClients.Put(key, ws)
 | 
					 | 
				
			||||||
						task := types.MjTask{
 | 
					 | 
				
			||||||
							UserId: userVo.Id,
 | 
					 | 
				
			||||||
							RoleId: role.Id,
 | 
					 | 
				
			||||||
							Icon:   "/images/avatar/mid_journey.png",
 | 
					 | 
				
			||||||
							ChatId: session.ChatId,
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
						err := h.leveldb.Put(types.TaskStorePrefix+key, task)
 | 
					 | 
				
			||||||
						if err != nil {
 | 
					 | 
				
			||||||
							logger.Error("error with store MidJourney task: ", err)
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
											content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
				
			||||||
 | 
											h.App.MjTaskClients.Put(session.SessionId, ws)
 | 
				
			||||||
						// update user's img_calls
 | 
											// update user's img_calls
 | 
				
			||||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
											h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,16 +3,18 @@ package handler
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"chatplus/core"
 | 
						"chatplus/core"
 | 
				
			||||||
	"chatplus/core/types"
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						"chatplus/service"
 | 
				
			||||||
	"chatplus/service/function"
 | 
						"chatplus/service/function"
 | 
				
			||||||
	"chatplus/service/oss"
 | 
						"chatplus/service/oss"
 | 
				
			||||||
	"chatplus/store"
 | 
					 | 
				
			||||||
	"chatplus/store/model"
 | 
						"chatplus/store/model"
 | 
				
			||||||
	"chatplus/utils"
 | 
						"chatplus/utils"
 | 
				
			||||||
	"chatplus/utils/resp"
 | 
						"chatplus/utils/resp"
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -38,25 +40,26 @@ type Image struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type MidJourneyHandler struct {
 | 
					type MidJourneyHandler struct {
 | 
				
			||||||
	BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	leveldb         *store.LevelDB
 | 
						redis           *redis.Client
 | 
				
			||||||
	db              *gorm.DB
 | 
						db              *gorm.DB
 | 
				
			||||||
	mjFunc          function.FuncMidJourney
 | 
						mjService       *service.MjService
 | 
				
			||||||
	uploaderManager *oss.UploaderManager
 | 
						uploaderManager *oss.UploaderManager
 | 
				
			||||||
	lock            sync.Mutex
 | 
						lock            sync.Mutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewMidJourneyHandler(
 | 
					func NewMidJourneyHandler(
 | 
				
			||||||
	app *core.AppServer,
 | 
						app *core.AppServer,
 | 
				
			||||||
	leveldb *store.LevelDB,
 | 
						client *redis.Client,
 | 
				
			||||||
	db *gorm.DB,
 | 
						db *gorm.DB,
 | 
				
			||||||
	manager *oss.UploaderManager,
 | 
						manager *oss.UploaderManager,
 | 
				
			||||||
	functions map[string]function.Function) *MidJourneyHandler {
 | 
						mjService *service.MjService) *MidJourneyHandler {
 | 
				
			||||||
	h := MidJourneyHandler{
 | 
						h := MidJourneyHandler{
 | 
				
			||||||
		leveldb:         leveldb,
 | 
							redis:           client,
 | 
				
			||||||
		db:              db,
 | 
							db:              db,
 | 
				
			||||||
		uploaderManager: manager,
 | 
							uploaderManager: manager,
 | 
				
			||||||
		lock:            sync.Mutex{},
 | 
							lock:            sync.Mutex{},
 | 
				
			||||||
		mjFunc:          functions[types.FuncMidJourney].(function.FuncMidJourney)}
 | 
							mjService:       mjService,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	h.App = app
 | 
						h.App = app
 | 
				
			||||||
	return &h
 | 
						return &h
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -75,7 +78,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
				
			|||||||
		Content     string     `json:"content"`
 | 
							Content     string     `json:"content"`
 | 
				
			||||||
		Prompt      string     `json:"prompt"`
 | 
							Prompt      string     `json:"prompt"`
 | 
				
			||||||
		Status      TaskStatus `json:"status"`
 | 
							Status      TaskStatus `json:"status"`
 | 
				
			||||||
		Key         string     `json:"key"`
 | 
							Progress    int        `json:"progress"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
						if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
				
			||||||
		resp.ERROR(c, types.InvalidArgs)
 | 
							resp.ERROR(c, types.InvalidArgs)
 | 
				
			||||||
@@ -86,95 +89,142 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
				
			|||||||
	h.lock.Lock()
 | 
						h.lock.Lock()
 | 
				
			||||||
	defer h.lock.Unlock()
 | 
						defer h.lock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// the job is saved
 | 
						taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
 | 
				
			||||||
	var job model.MidJourneyJob
 | 
						if err != nil {
 | 
				
			||||||
	res := h.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
							resp.SUCCESS(c) // 过期任务,丢弃
 | 
				
			||||||
	if res.Error == nil {
 | 
					 | 
				
			||||||
		resp.SUCCESS(c)
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	data.Key = utils.Sha256(data.Prompt)
 | 
						var task service.MjTask
 | 
				
			||||||
	wsClient := h.App.MjTaskClients.Get(data.Key)
 | 
						err = utils.JsonDecode(taskString, &task)
 | 
				
			||||||
	//logger.Info(data.Prompt, ",", key)
 | 
						if err != nil {
 | 
				
			||||||
	if data.Status == Finished {
 | 
							resp.SUCCESS(c) // 非标准任务,丢弃
 | 
				
			||||||
		var task types.MjTask
 | 
							return
 | 
				
			||||||
		err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task)
 | 
						}
 | 
				
			||||||
		if err != nil {
 | 
					
 | 
				
			||||||
			logger.Error("error with get MidJourney task: ", err)
 | 
						if task.Src == service.TaskSrcImg { // 绘画任务
 | 
				
			||||||
 | 
							logger.Error(err)
 | 
				
			||||||
 | 
							var job model.MidJourneyJob
 | 
				
			||||||
 | 
							res := h.db.First(&job, task.Id)
 | 
				
			||||||
 | 
							if res.Error != nil {
 | 
				
			||||||
 | 
								resp.SUCCESS(c) // 非法任务,丢弃
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							job.MessageId = data.MessageId
 | 
				
			||||||
 | 
							job.ReferenceId = data.ReferenceId
 | 
				
			||||||
 | 
							job.Progress = data.Progress
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// download image
 | 
				
			||||||
 | 
							if data.Progress == 100 {
 | 
				
			||||||
 | 
								imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									resp.ERROR(c, "error with download img: "+err.Error())
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								job.ImgURL = imgURL
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								// 使用图片代理
 | 
				
			||||||
 | 
								job.ImgURL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							res = h.db.Updates(&job)
 | 
				
			||||||
 | 
							if res.Error != nil {
 | 
				
			||||||
 | 
								resp.ERROR(c, "error with update job: "+err.Error())
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							resp.SUCCESS(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						} else if task.Src == service.TaskSrcChat { // 聊天任务
 | 
				
			||||||
 | 
							var job model.MidJourneyJob
 | 
				
			||||||
 | 
							res := h.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
				
			||||||
 | 
							if res.Error == nil {
 | 
				
			||||||
			resp.SUCCESS(c)
 | 
								resp.SUCCESS(c)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if wsClient != nil && data.ReferenceId != "" {
 | 
					
 | 
				
			||||||
			content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
							wsClient := h.App.MjTaskClients.Get(task.Id)
 | 
				
			||||||
			utils.ReplyMessage(wsClient, content)
 | 
							if data.Status == Finished {
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		// download image
 | 
					 | 
				
			||||||
		imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			logger.Error("error with download image: ", err)
 | 
					 | 
				
			||||||
			if wsClient != nil && data.ReferenceId != "" {
 | 
								if wsClient != nil && data.ReferenceId != "" {
 | 
				
			||||||
				content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
									content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
				
			||||||
				utils.ReplyMessage(wsClient, content)
 | 
									utils.ReplyMessage(wsClient, content)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			resp.ERROR(c, err.Error())
 | 
								// download image
 | 
				
			||||||
 | 
								imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									logger.Error("error with download image: ", err)
 | 
				
			||||||
 | 
									if wsClient != nil && data.ReferenceId != "" {
 | 
				
			||||||
 | 
										content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
				
			||||||
 | 
										utils.ReplyMessage(wsClient, content)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									resp.ERROR(c, err.Error())
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								data.Image.URL = imgURL
 | 
				
			||||||
 | 
								message := model.HistoryMessage{
 | 
				
			||||||
 | 
									UserId:     uint(task.UserId),
 | 
				
			||||||
 | 
									ChatId:     task.ChatId,
 | 
				
			||||||
 | 
									RoleId:     uint(task.RoleId),
 | 
				
			||||||
 | 
									Type:       types.MjMsg,
 | 
				
			||||||
 | 
									Icon:       task.Icon,
 | 
				
			||||||
 | 
									Content:    utils.JsonEncode(data),
 | 
				
			||||||
 | 
									Tokens:     0,
 | 
				
			||||||
 | 
									UseContext: false,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								res := h.db.Create(&message)
 | 
				
			||||||
 | 
								if res.Error != nil {
 | 
				
			||||||
 | 
									logger.Error("error with save chat history message: ", res.Error)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// save the job
 | 
				
			||||||
 | 
								job.UserId = task.UserId
 | 
				
			||||||
 | 
								job.MessageId = data.MessageId
 | 
				
			||||||
 | 
								job.ReferenceId = data.ReferenceId
 | 
				
			||||||
 | 
								job.Prompt = data.Prompt
 | 
				
			||||||
 | 
								job.ImgURL = imgURL
 | 
				
			||||||
 | 
								job.Progress = data.Progress
 | 
				
			||||||
 | 
								job.CreatedAt = time.Now()
 | 
				
			||||||
 | 
								res = h.db.Create(&job)
 | 
				
			||||||
 | 
								if res.Error != nil {
 | 
				
			||||||
 | 
									logger.Error("error with save MidJourney Job: ", res.Error)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if wsClient == nil { // 客户端断线,则丢弃
 | 
				
			||||||
 | 
								logger.Errorf("Client is offline: %+v", data)
 | 
				
			||||||
 | 
								resp.SUCCESS(c, "Client is offline")
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		data.Image.URL = imgURL
 | 
							if data.Status == Finished {
 | 
				
			||||||
		message := model.HistoryMessage{
 | 
								utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
				
			||||||
			UserId:     task.UserId,
 | 
								utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
				
			||||||
			ChatId:     task.ChatId,
 | 
								// delete client
 | 
				
			||||||
			RoleId:     task.RoleId,
 | 
								h.App.MjTaskClients.Delete(task.Id)
 | 
				
			||||||
			Type:       types.MjMsg,
 | 
							} else {
 | 
				
			||||||
			Icon:       task.Icon,
 | 
								//// 使用代理临时转发图片
 | 
				
			||||||
			Content:    utils.JsonEncode(data),
 | 
								//if data.Image.URL != "" {
 | 
				
			||||||
			Tokens:     0,
 | 
								//	image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
 | 
				
			||||||
			UseContext: false,
 | 
								//	if err == nil {
 | 
				
			||||||
		}
 | 
								//		data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
				
			||||||
		res := h.db.Create(&message)
 | 
								//	}
 | 
				
			||||||
		if res.Error != nil {
 | 
								//}
 | 
				
			||||||
			logger.Error("error with save chat history message: ", res.Error)
 | 
								data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL)
 | 
				
			||||||
		}
 | 
								utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
				
			||||||
 | 
					 | 
				
			||||||
		// save the job
 | 
					 | 
				
			||||||
		job.UserId = task.UserId
 | 
					 | 
				
			||||||
		job.ChatId = task.ChatId
 | 
					 | 
				
			||||||
		job.MessageId = data.MessageId
 | 
					 | 
				
			||||||
		job.ReferenceId = data.ReferenceId
 | 
					 | 
				
			||||||
		job.Content = data.Content
 | 
					 | 
				
			||||||
		job.Prompt = data.Prompt
 | 
					 | 
				
			||||||
		job.Image = utils.JsonEncode(data.Image)
 | 
					 | 
				
			||||||
		job.Hash = data.Image.Hash
 | 
					 | 
				
			||||||
		job.CreatedAt = time.Now()
 | 
					 | 
				
			||||||
		res = h.db.Create(&job)
 | 
					 | 
				
			||||||
		if res.Error != nil {
 | 
					 | 
				
			||||||
			logger.Error("error with save MidJourney Job: ", res.Error)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							resp.SUCCESS(c, "SUCCESS")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if wsClient == nil { // 客户端断线,则丢弃
 | 
					}
 | 
				
			||||||
		logger.Errorf("Client is offline: %+v", data)
 | 
					
 | 
				
			||||||
		resp.SUCCESS(c, "Client is offline")
 | 
					func (h *MidJourneyHandler) Proxy(c *gin.Context) {
 | 
				
			||||||
 | 
						url := c.Query("url")
 | 
				
			||||||
 | 
						image, err := utils.DownloadImage(url, h.App.Config.ProxyURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.String(http.StatusOK, err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						c.String(http.StatusOK, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(image))
 | 
				
			||||||
	if data.Status == Finished {
 | 
					 | 
				
			||||||
		utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
					 | 
				
			||||||
		utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
					 | 
				
			||||||
		// delete client
 | 
					 | 
				
			||||||
		h.App.MjTaskClients.Delete(data.Key)
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		// 使用代理临时转发图片
 | 
					 | 
				
			||||||
		if data.Image.URL != "" {
 | 
					 | 
				
			||||||
			image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
 | 
					 | 
				
			||||||
			if err == nil {
 | 
					 | 
				
			||||||
				data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	resp.SUCCESS(c, "SUCCESS")
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type reqVo struct {
 | 
					type reqVo struct {
 | 
				
			||||||
@@ -201,7 +251,12 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := h.mjFunc.Upscale(function.MjUpscaleReq{
 | 
						h.mjService.PushTask(service.MjTask{
 | 
				
			||||||
 | 
							Index:       data.Index,
 | 
				
			||||||
 | 
							MessageId:   data.MessageId,
 | 
				
			||||||
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						err := n.Upscale(function.MjUpscaleReq{
 | 
				
			||||||
		Index:       data.Index,
 | 
							Index:       data.Index,
 | 
				
			||||||
		MessageId:   data.MessageId,
 | 
							MessageId:   data.MessageId,
 | 
				
			||||||
		MessageHash: data.MessageHash,
 | 
							MessageHash: data.MessageHash,
 | 
				
			||||||
@@ -211,7 +266,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	content := fmt.Sprintf("**%s** 已推送 Upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
						content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
				
			||||||
	utils.ReplyMessage(wsClient, content)
 | 
						utils.ReplyMessage(wsClient, content)
 | 
				
			||||||
	if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
						if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
				
			||||||
		h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
							h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
				
			||||||
@@ -242,7 +297,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
				
			|||||||
		resp.ERROR(c, err.Error())
 | 
							resp.ERROR(c, err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	content := fmt.Sprintf("**%s** 已推送 Variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
						content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
				
			||||||
	utils.ReplyMessage(wsClient, content)
 | 
						utils.ReplyMessage(wsClient, content)
 | 
				
			||||||
	if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
						if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
				
			||||||
		h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
							h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -131,6 +131,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				utils.ReplyMessage(ws, "")
 | 
									utils.ReplyMessage(ws, "")
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				f := h.App.Functions[functionName]
 | 
									f := h.App.Functions[functionName]
 | 
				
			||||||
 | 
									if functionName == types.FuncMidJourney {
 | 
				
			||||||
 | 
										params["user_id"] = userVo.Id
 | 
				
			||||||
 | 
										params["role_id"] = role.Id
 | 
				
			||||||
 | 
										params["chat_id"] = session.ChatId
 | 
				
			||||||
 | 
										params["icon"] = "/images/avatar/mid_journey.png"
 | 
				
			||||||
 | 
										params["session_id"] = session.SessionId
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				data, err := f.Invoke(params)
 | 
									data, err := f.Invoke(params)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					msg := "调用函数出错:" + err.Error()
 | 
										msg := "调用函数出错:" + err.Error()
 | 
				
			||||||
@@ -142,22 +149,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					content := data
 | 
										content := data
 | 
				
			||||||
					if functionName == types.FuncMidJourney {
 | 
										if functionName == types.FuncMidJourney {
 | 
				
			||||||
						key := utils.Sha256(data)
 | 
					 | 
				
			||||||
						logger.Debug(data, ",", key)
 | 
					 | 
				
			||||||
						// add task for MidJourney
 | 
					 | 
				
			||||||
						h.App.MjTaskClients.Put(key, ws)
 | 
					 | 
				
			||||||
						task := types.MjTask{
 | 
					 | 
				
			||||||
							UserId: userVo.Id,
 | 
					 | 
				
			||||||
							RoleId: role.Id,
 | 
					 | 
				
			||||||
							Icon:   "/images/avatar/mid_journey.png",
 | 
					 | 
				
			||||||
							ChatId: session.ChatId,
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
						err := h.leveldb.Put(types.TaskStorePrefix+key, task)
 | 
					 | 
				
			||||||
						if err != nil {
 | 
					 | 
				
			||||||
							logger.Error("error with store MidJourney task: ", err)
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
											content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
				
			||||||
 | 
											h.App.MjTaskClients.Put(session.SessionId, ws)
 | 
				
			||||||
						// update user's img_calls
 | 
											// update user's img_calls
 | 
				
			||||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
											h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										14
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								api/main.go
									
									
									
									
									
								
							@@ -135,6 +135,12 @@ func main() {
 | 
				
			|||||||
			return service.NewCaptchaService(config.ApiConfig)
 | 
								return service.NewCaptchaService(config.ApiConfig)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Provide(oss.NewUploaderManager),
 | 
							fx.Provide(oss.NewUploaderManager),
 | 
				
			||||||
 | 
							fx.Provide(service.NewMjService),
 | 
				
			||||||
 | 
							fx.Provide(func(mjService *service.MjService) {
 | 
				
			||||||
 | 
								go func() {
 | 
				
			||||||
 | 
									mjService.Run()
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 注册路由
 | 
							// 注册路由
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
 | 
				
			||||||
@@ -183,9 +189,11 @@ func main() {
 | 
				
			|||||||
			group.POST("verify", h.Verify)
 | 
								group.POST("verify", h.Verify)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
				
			||||||
			s.Engine.POST("/api/mj/notify", h.Notify)
 | 
								group := s.Engine.Group("/api/mj/")
 | 
				
			||||||
			s.Engine.POST("/api/mj/upscale", h.Upscale)
 | 
								group.POST("notify", h.Notify)
 | 
				
			||||||
			s.Engine.POST("/api/mj/variation", h.Variation)
 | 
								group.POST("upscale", h.Upscale)
 | 
				
			||||||
 | 
								group.POST("variation", h.Variation)
 | 
				
			||||||
 | 
								group.GET("proxy", h.Proxy)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 管理后台控制器
 | 
							// 管理后台控制器
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										64
									
								
								api/service/function/func_mj.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								api/service/function/func_mj.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					package function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"chatplus/service"
 | 
				
			||||||
 | 
						"chatplus/utils"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AI 绘画函数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type FuncMidJourney struct {
 | 
				
			||||||
 | 
						name    string
 | 
				
			||||||
 | 
						service *service.MjService
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
 | 
				
			||||||
 | 
						return FuncMidJourney{
 | 
				
			||||||
 | 
							name:    "MidJourney AI 绘画",
 | 
				
			||||||
 | 
							service: mjService}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
 | 
				
			||||||
 | 
						logger.Infof("MJ 绘画参数:%+v", params)
 | 
				
			||||||
 | 
						prompt := utils.InterfaceToString(params["prompt"])
 | 
				
			||||||
 | 
						if !utils.IsEmptyValue(params["--ar"]) {
 | 
				
			||||||
 | 
							prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
 | 
				
			||||||
 | 
							delete(params, "--ar")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !utils.IsEmptyValue(params["--s"]) {
 | 
				
			||||||
 | 
							prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
 | 
				
			||||||
 | 
							delete(params, "--s")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !utils.IsEmptyValue(params["--seed"]) {
 | 
				
			||||||
 | 
							prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
 | 
				
			||||||
 | 
							delete(params, "--seed")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !utils.IsEmptyValue(params["--no"]) {
 | 
				
			||||||
 | 
							prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
 | 
				
			||||||
 | 
							delete(params, "--no")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !utils.IsEmptyValue(params["--niji"]) {
 | 
				
			||||||
 | 
							prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
 | 
				
			||||||
 | 
							delete(params, "--niji")
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							prompt = prompt + " --v 5.2"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						f.service.PushTask(service.MjTask{
 | 
				
			||||||
 | 
							Id:     utils.InterfaceToString(params["session_id"]),
 | 
				
			||||||
 | 
							Src:    service.TaskSrcChat,
 | 
				
			||||||
 | 
							Prompt: prompt,
 | 
				
			||||||
 | 
							UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
 | 
				
			||||||
 | 
							RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
 | 
				
			||||||
 | 
							Icon:   utils.InterfaceToString(params["icon"]),
 | 
				
			||||||
 | 
							ChatId: utils.InterfaceToString(params["chat_id"]),
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return prompt, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f FuncMidJourney) Name() string {
 | 
				
			||||||
 | 
						return f.name
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var _ Function = &FuncMidJourney{}
 | 
				
			||||||
@@ -1,129 +0,0 @@
 | 
				
			|||||||
package function
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"chatplus/core/types"
 | 
					 | 
				
			||||||
	"chatplus/utils"
 | 
					 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"github.com/imroc/req/v3"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// AI 绘画函数
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type FuncMidJourney struct {
 | 
					 | 
				
			||||||
	name   string
 | 
					 | 
				
			||||||
	config types.ChatPlusExtConfig
 | 
					 | 
				
			||||||
	client *req.Client
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
 | 
					 | 
				
			||||||
	return FuncMidJourney{
 | 
					 | 
				
			||||||
		name:   "MidJourney AI 绘画",
 | 
					 | 
				
			||||||
		config: config,
 | 
					 | 
				
			||||||
		client: req.C().SetTimeout(30 * time.Second)}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
 | 
					 | 
				
			||||||
	if f.config.Token == "" {
 | 
					 | 
				
			||||||
		return "", errors.New("无效的 API Token")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	logger.Infof("MJ 绘画参数:%+v", params)
 | 
					 | 
				
			||||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
					 | 
				
			||||||
	if !utils.IsEmptyValue(params["--ar"]) {
 | 
					 | 
				
			||||||
		prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
 | 
					 | 
				
			||||||
		delete(params, "--ar")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !utils.IsEmptyValue(params["--s"]) {
 | 
					 | 
				
			||||||
		prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
 | 
					 | 
				
			||||||
		delete(params, "--s")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !utils.IsEmptyValue(params["--seed"]) {
 | 
					 | 
				
			||||||
		prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
 | 
					 | 
				
			||||||
		delete(params, "--seed")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !utils.IsEmptyValue(params["--no"]) {
 | 
					 | 
				
			||||||
		prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
 | 
					 | 
				
			||||||
		delete(params, "--no")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if !utils.IsEmptyValue(params["--niji"]) {
 | 
					 | 
				
			||||||
		prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
 | 
					 | 
				
			||||||
		delete(params, "--niji")
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		prompt = prompt + " --v 5.2"
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	params["prompt"] = prompt
 | 
					 | 
				
			||||||
	url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL)
 | 
					 | 
				
			||||||
	var res types.BizVo
 | 
					 | 
				
			||||||
	r, err := f.client.R().
 | 
					 | 
				
			||||||
		SetHeader("Authorization", f.config.Token).
 | 
					 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
					 | 
				
			||||||
		SetBody(params).
 | 
					 | 
				
			||||||
		SetSuccessResult(&res).Post(url)
 | 
					 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("%v%v", r.String(), err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if res.Code != types.Success {
 | 
					 | 
				
			||||||
		return "", errors.New(res.Message)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return prompt, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type MjUpscaleReq struct {
 | 
					 | 
				
			||||||
	Index       int32  `json:"index"`
 | 
					 | 
				
			||||||
	MessageId   string `json:"message_id"`
 | 
					 | 
				
			||||||
	MessageHash string `json:"message_hash"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error {
 | 
					 | 
				
			||||||
	url := fmt.Sprintf("%s/api/mj/upscale", f.config.ApiURL)
 | 
					 | 
				
			||||||
	var res types.BizVo
 | 
					 | 
				
			||||||
	r, err := f.client.R().
 | 
					 | 
				
			||||||
		SetHeader("Authorization", f.config.Token).
 | 
					 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
					 | 
				
			||||||
		SetBody(upReq).
 | 
					 | 
				
			||||||
		SetSuccessResult(&res).Post(url)
 | 
					 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
					 | 
				
			||||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if res.Code != types.Success {
 | 
					 | 
				
			||||||
		return errors.New(res.Message)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type MjVariationReq struct {
 | 
					 | 
				
			||||||
	Index       int32  `json:"index"`
 | 
					 | 
				
			||||||
	MessageId   string `json:"message_id"`
 | 
					 | 
				
			||||||
	MessageHash string `json:"message_hash"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (f FuncMidJourney) Variation(upReq MjVariationReq) error {
 | 
					 | 
				
			||||||
	url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL)
 | 
					 | 
				
			||||||
	var res types.BizVo
 | 
					 | 
				
			||||||
	r, err := f.client.R().
 | 
					 | 
				
			||||||
		SetHeader("Authorization", f.config.Token).
 | 
					 | 
				
			||||||
		SetHeader("Content-Type", "application/json").
 | 
					 | 
				
			||||||
		SetBody(upReq).
 | 
					 | 
				
			||||||
		SetSuccessResult(&res).Post(url)
 | 
					 | 
				
			||||||
	if err != nil || r.IsErrorState() {
 | 
					 | 
				
			||||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if res.Code != types.Success {
 | 
					 | 
				
			||||||
		return errors.New(res.Message)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (f FuncMidJourney) Name() string {
 | 
					 | 
				
			||||||
	return f.name
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var _ Function = &FuncMidJourney{}
 | 
					 | 
				
			||||||
							
								
								
									
										189
									
								
								api/service/mj_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								api/service/mj_service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,189 @@
 | 
				
			|||||||
 | 
					package service
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"chatplus/core/types"
 | 
				
			||||||
 | 
						logger2 "chatplus/logger"
 | 
				
			||||||
 | 
						"chatplus/store"
 | 
				
			||||||
 | 
						"chatplus/utils"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
 | 
						"github.com/imroc/req/v3"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var logger = logger2.GetLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MJ 绘画服务
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const MjRunningJobKey = "MidJourney_Running_Job"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TaskType string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						Image     = TaskType("image")
 | 
				
			||||||
 | 
						Upscale   = TaskType("upscale")
 | 
				
			||||||
 | 
						Variation = TaskType("variation")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TaskSrc string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						TaskSrcChat = TaskSrc("chat")
 | 
				
			||||||
 | 
						TaskSrcImg  = TaskSrc("img")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MjTask struct {
 | 
				
			||||||
 | 
						Id          string   `json:"id"`
 | 
				
			||||||
 | 
						Src         TaskSrc  `json:"src"`
 | 
				
			||||||
 | 
						Type        TaskType `json:"type"`
 | 
				
			||||||
 | 
						UserId      int      `json:"user_id"`
 | 
				
			||||||
 | 
						Prompt      string   `json:"prompt,omitempty"`
 | 
				
			||||||
 | 
						ChatId      string   `json:"chat_id,omitempty"`
 | 
				
			||||||
 | 
						RoleId      int      `json:"role_id,omitempty"`
 | 
				
			||||||
 | 
						Icon        string   `json:"icon,omitempty"`
 | 
				
			||||||
 | 
						Index       int32    `json:"index,omitempty"`
 | 
				
			||||||
 | 
						MessageId   string   `json:"message_id,omitempty"`
 | 
				
			||||||
 | 
						MessageHash string   `json:"message_hash,omitempty"`
 | 
				
			||||||
 | 
						RetryCount  int      `json:"retry_count"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MjService struct {
 | 
				
			||||||
 | 
						config    types.ChatPlusExtConfig
 | 
				
			||||||
 | 
						client    *req.Client
 | 
				
			||||||
 | 
						taskQueue *store.RedisQueue
 | 
				
			||||||
 | 
						redis     *redis.Client
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewMjService(config types.ChatPlusExtConfig, client *redis.Client) *MjService {
 | 
				
			||||||
 | 
						return &MjService{
 | 
				
			||||||
 | 
							config:    config,
 | 
				
			||||||
 | 
							redis:     client,
 | 
				
			||||||
 | 
							taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
 | 
				
			||||||
 | 
							client:    req.C().SetTimeout(30 * time.Second)}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *MjService) Run() {
 | 
				
			||||||
 | 
						ctx := context.Background()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
 | 
				
			||||||
 | 
							if err == nil { // a task is running, waiting for finish
 | 
				
			||||||
 | 
								time.Sleep(time.Second * 3)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							var task MjTask
 | 
				
			||||||
 | 
							err = s.taskQueue.LPop(&task)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logger.Errorf("taking task with error: %v", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							switch task.Type {
 | 
				
			||||||
 | 
							case Image:
 | 
				
			||||||
 | 
								err = s.image(task.Prompt)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							case Upscale:
 | 
				
			||||||
 | 
								err = s.upscale(MjUpscaleReq{
 | 
				
			||||||
 | 
									Index:       task.Index,
 | 
				
			||||||
 | 
									MessageId:   task.MessageId,
 | 
				
			||||||
 | 
									MessageHash: task.MessageHash,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							case Variation:
 | 
				
			||||||
 | 
								err = s.variation(MjVariationReq{
 | 
				
			||||||
 | 
									Index:       task.Index,
 | 
				
			||||||
 | 
									MessageId:   task.MessageId,
 | 
				
			||||||
 | 
									MessageHash: task.MessageHash,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								if task.RetryCount > 5 {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								task.RetryCount += 1
 | 
				
			||||||
 | 
								time.Sleep(time.Second)
 | 
				
			||||||
 | 
								s.taskQueue.RPush(task)
 | 
				
			||||||
 | 
								// TODO: 执行失败通知聊天客户端
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 锁定任务执行通道,直到任务超时(10分钟)
 | 
				
			||||||
 | 
							s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *MjService) PushTask(task MjTask) {
 | 
				
			||||||
 | 
						s.taskQueue.RPush(task)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *MjService) image(prompt string) error {
 | 
				
			||||||
 | 
						logger.Infof("MJ 绘画参数:%+v", prompt)
 | 
				
			||||||
 | 
						body := map[string]string{"prompt": prompt}
 | 
				
			||||||
 | 
						url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
 | 
				
			||||||
 | 
						var res types.BizVo
 | 
				
			||||||
 | 
						r, err := s.client.R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", s.config.Token).
 | 
				
			||||||
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
 | 
							SetBody(body).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).Post(url)
 | 
				
			||||||
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
 | 
							return fmt.Errorf("%v%v", r.String(), err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if res.Code != types.Success {
 | 
				
			||||||
 | 
							return errors.New(res.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MjUpscaleReq struct {
 | 
				
			||||||
 | 
						Index       int32  `json:"index"`
 | 
				
			||||||
 | 
						MessageId   string `json:"message_id"`
 | 
				
			||||||
 | 
						MessageHash string `json:"message_hash"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *MjService) upscale(upReq MjUpscaleReq) error {
 | 
				
			||||||
 | 
						url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL)
 | 
				
			||||||
 | 
						var res types.BizVo
 | 
				
			||||||
 | 
						r, err := s.client.R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", s.config.Token).
 | 
				
			||||||
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
 | 
							SetBody(upReq).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).Post(url)
 | 
				
			||||||
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
 | 
							return fmt.Errorf("%v%v", r.String(), err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if res.Code != types.Success {
 | 
				
			||||||
 | 
							return errors.New(res.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MjVariationReq struct {
 | 
				
			||||||
 | 
						Index       int32  `json:"index"`
 | 
				
			||||||
 | 
						MessageId   string `json:"message_id"`
 | 
				
			||||||
 | 
						MessageHash string `json:"message_hash"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *MjService) variation(upReq MjVariationReq) error {
 | 
				
			||||||
 | 
						url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL)
 | 
				
			||||||
 | 
						var res types.BizVo
 | 
				
			||||||
 | 
						r, err := s.client.R().
 | 
				
			||||||
 | 
							SetHeader("Authorization", s.config.Token).
 | 
				
			||||||
 | 
							SetHeader("Content-Type", "application/json").
 | 
				
			||||||
 | 
							SetBody(upReq).
 | 
				
			||||||
 | 
							SetSuccessResult(&res).Post(url)
 | 
				
			||||||
 | 
						if err != nil || r.IsErrorState() {
 | 
				
			||||||
 | 
							return fmt.Errorf("%v%v", r.String(), err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if res.Code != types.Success {
 | 
				
			||||||
 | 
							return errors.New(res.Message)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -4,14 +4,13 @@ import "time"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type MidJourneyJob struct {
 | 
					type MidJourneyJob struct {
 | 
				
			||||||
	Id          uint `gorm:"primarykey;column:id"`
 | 
						Id          uint `gorm:"primarykey;column:id"`
 | 
				
			||||||
	UserId      uint
 | 
						UserId      int
 | 
				
			||||||
	ChatId      string
 | 
					 | 
				
			||||||
	MessageId   string
 | 
						MessageId   string
 | 
				
			||||||
	ReferenceId string
 | 
						ReferenceId string
 | 
				
			||||||
	Hash        string
 | 
						ImgURL      string
 | 
				
			||||||
	Content     string
 | 
						Hash        string // message hash
 | 
				
			||||||
 | 
						Progress    int
 | 
				
			||||||
	Prompt      string
 | 
						Prompt      string
 | 
				
			||||||
	Image       string
 | 
					 | 
				
			||||||
	CreatedAt   time.Time
 | 
						CreatedAt   time.Time
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										41
									
								
								api/store/redis_queue.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								api/store/redis_queue.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					package store
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"chatplus/utils"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RedisQueue struct {
 | 
				
			||||||
 | 
						name   string
 | 
				
			||||||
 | 
						client *redis.Client
 | 
				
			||||||
 | 
						ctx    context.Context
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRedisQueue(name string, client *redis.Client) *RedisQueue {
 | 
				
			||||||
 | 
						return &RedisQueue{name: name, client: client, ctx: context.Background()}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *RedisQueue) RPush(value interface{}) {
 | 
				
			||||||
 | 
						q.client.RPush(q.ctx, q.name, utils.JsonEncode(value))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *RedisQueue) LPush(value interface{}) {
 | 
				
			||||||
 | 
						q.client.LPush(q.ctx, q.name, utils.JsonEncode(value))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *RedisQueue) LPop(value interface{}) error {
 | 
				
			||||||
 | 
						result, err := q.client.BLPop(q.ctx, 0, q.name).Result()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return utils.JsonDecode(result[1], value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *RedisQueue) RPop(value interface{}) error {
 | 
				
			||||||
 | 
						result, err := q.client.BRPop(q.ctx, 0, q.name).Result()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return utils.JsonDecode(result[1], value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										4
									
								
								database/update-3.1.3.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								database/update-3.1.3.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` DROP `image`;
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` ADD `progress` SMALLINT(5) NULL DEFAULT '0' COMMENT '任务进度' AFTER `prompt`;
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` ADD `hash` VARCHAR(100) NULL DEFAULT NULL COMMENT 'message hash' AFTER `prompt`;
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_mj_jobs` ADD `img_url` VARCHAR(255) NULL DEFAULT NULL COMMENT '图片URL' AFTER `prompt`;
 | 
				
			||||||
@@ -109,7 +109,6 @@ const send = (url, index) => {
 | 
				
			|||||||
    message_id: data.value?.["message_id"],
 | 
					    message_id: data.value?.["message_id"],
 | 
				
			||||||
    message_hash: data.value?.["image"]?.hash,
 | 
					    message_hash: data.value?.["image"]?.hash,
 | 
				
			||||||
    session_id: getSessionId(),
 | 
					    session_id: getSessionId(),
 | 
				
			||||||
    key: data.value?.["key"],
 | 
					 | 
				
			||||||
    prompt: data.value?.["prompt"],
 | 
					    prompt: data.value?.["prompt"],
 | 
				
			||||||
  }).then(() => {
 | 
					  }).then(() => {
 | 
				
			||||||
    ElMessage.success("任务推送成功,请耐心等待任务执行...")
 | 
					    ElMessage.success("任务推送成功,请耐心等待任务执行...")
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user