diff --git a/api/Makefile b/api/Makefile index 70510fbe..897e6e33 100644 --- a/api/Makefile +++ b/api/Makefile @@ -12,7 +12,7 @@ linux: .PHONY: linux darwin: - CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o bin/$(NAME)-amd64-darwin main.go + CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/$(NAME)-amd64-darwin main.go .PHONY: darwin clean: diff --git a/api/core/app_server.go b/api/core/app_server.go index 85d1db31..99b1d2cd 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -151,7 +151,7 @@ func corsMiddleware() gin.HandlerFunc { c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") //允许跨域设置可以返回其他子段,可以自定义字段 - c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ADMIN-SESSION-TOKEN") + c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type") // 允许浏览器(客户端)可以解析的头部 (重要) c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") //设置缓存时间 diff --git a/api/core/types/client.go b/api/core/types/client.go index 90a21f4a..d1c80cd9 100644 --- a/api/core/types/client.go +++ b/api/core/types/client.go @@ -6,18 +6,14 @@ import ( "sync" ) -var ErrConClosed = errors.New("connection closed") - -type Client interface { - Close() -} +var ErrConClosed = errors.New("connection Closed") // WsClient websocket client type WsClient struct { Conn *websocket.Conn lock sync.Mutex mt int - closed bool + Closed bool } func NewWsClient(conn *websocket.Conn) *WsClient { @@ -25,7 +21,7 @@ func NewWsClient(conn *websocket.Conn) *WsClient { Conn: conn, lock: sync.Mutex{}, mt: 2, // fixed bug for 'Invalid UTF-8 in text frame' - closed: false, + Closed: false, } } @@ -33,7 +29,7 @@ func (wc *WsClient) Send(message []byte) error { wc.lock.Lock() defer wc.lock.Unlock() - if wc.closed { + if wc.Closed { return ErrConClosed } @@ -41,7 +37,7 @@ func (wc *WsClient) Send(message []byte) error { } func (wc *WsClient) Receive() (int, []byte, error) { - if wc.closed { + if wc.Closed { return 0, nil, ErrConClosed } @@ -52,10 +48,10 @@ func (wc *WsClient) Close() { wc.lock.Lock() defer wc.lock.Unlock() - if wc.closed { + if wc.Closed { return } _ = wc.Conn.Close() - wc.closed = true + wc.Closed = true } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 493b6a0c..d7ba91c9 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,12 +3,14 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service/function" "chatplus/store" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "github.com/gin-gonic/gin" "gorm.io/gorm" + "time" ) type TaskStatus string @@ -34,10 +36,11 @@ type MidJourneyHandler struct { BaseHandler leveldb *store.LevelDB db *gorm.DB + mjFunc function.FuncMidJourney } -func NewMidJourneyHandler(app *core.AppServer, leveldb *store.LevelDB, db *gorm.DB) *MidJourneyHandler { - h := MidJourneyHandler{leveldb: leveldb, db: db} +func NewMidJourneyHandler(app *core.AppServer, leveldb *store.LevelDB, db *gorm.DB, functions map[string]function.Function) *MidJourneyHandler { + h := MidJourneyHandler{leveldb: leveldb, db: db, mjFunc: functions[types.FuncMidJourney].(function.FuncMidJourney)} h.App = app return &h } @@ -50,34 +53,43 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } var data struct { - Type string `json:"type"` - MessageId string `json:"message_id"` - Image Image `json:"image"` - Content string `json:"content"` - Prompt string `json:"prompt"` - Status TaskStatus `json:"status"` - Key string `json:"key"` + MessageId string `json:"message_id"` + ReferenceId string `json:"reference_id"` + Image Image `json:"image"` + Content string `json:"content"` + Prompt string `json:"prompt"` + Status TaskStatus `json:"status"` + Key string `json:"key"` } if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { resp.ERROR(c, types.InvalidArgs) return } - key := utils.Sha256(data.Prompt) - data.Key = key + logger.Infof("收到 MidJourney 回调请求:%+v", data) + + // the job is saved + var job model.MidJourneyJob + res := h.db.Where("message_id = ?", data.MessageId).First(&job) + if res.Error == nil { + resp.SUCCESS(c) + return + } + + data.Key = utils.Sha256(data.Prompt) //logger.Info(data.Prompt, ",", key) if data.Status == Finished { var task types.MjTask - err := h.leveldb.Get(types.TaskStorePrefix+key, &task) + err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task) if err != nil { logger.Error("error with get MidJourney task: ", err) - resp.ERROR(c) + resp.ERROR(c, err.Error()) return } // TODO: 是否需要把图片下载到本地服务器? - historyUserMsg := model.HistoryMessage{ + message := model.HistoryMessage{ UserId: task.UserId, ChatId: task.ChatId, RoleId: task.RoleId, @@ -87,18 +99,30 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { Tokens: 0, UseContext: false, } - res := h.db.Save(&historyUserMsg) + res := h.db.Create(&message) if res.Error != nil { - logger.Error("error with save MidJourney message: ", res.Error) + logger.Error("error with save chat history message: ", res.Error) } - // delete task from leveldb - _ = h.leveldb.Delete(types.TaskStorePrefix + key) + // save the job + job.UserId = task.UserId + job.ChatId = task.ChatId + job.MessageId = data.MessageId + job.Content = data.Content + job.Prompt = data.Prompt + job.Image = utils.JsonEncode(data.Image) + job.Hash = data.Image.Hash + job.CreatedAt = time.Now() + res = h.db.Create(&job) + if res.Error != nil { + logger.Error("error with save MidJourney Job: ", res.Error) + } } // 推送消息到客户端 - wsClient := h.App.MjTaskClients.Get(key) + wsClient := h.App.MjTaskClients.Get(data.Key) if wsClient == nil { // 客户端断线,则丢弃 + logger.Errorf("Client is offline: %+v", data) resp.SUCCESS(c, "Client is offline") return } @@ -107,9 +131,48 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) // delete client - h.App.MjTaskClients.Delete(key) + h.App.MjTaskClients.Delete(data.Key) } else { utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) } resp.SUCCESS(c, "SUCCESS") } + +// Upscale send upscale command to MidJourney Bot +func (h *MidJourneyHandler) Upscale(c *gin.Context) { + var data struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` + SessionId string `json:"session_id"` + Key string `json:"key"` + } + + if err := c.ShouldBindJSON(&data); err != nil || + data.SessionId == "" || + data.Key == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + wsClient := h.App.ChatClients.Get(data.SessionId) + if wsClient == nil { + resp.ERROR(c, "No Websocket client online") + return + } + + err := h.mjFunc.Upscale(function.MjUpscaleReq{ + Index: data.Index, + MessageId: data.MessageId, + MessageHash: data.MessageHash, + }) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + utils.ReplyMessage(wsClient, "已推送放大图片任务到 MidJourney 机器人,请耐心等待任务执行...") + if h.App.MjTaskClients.Get(data.Key) == nil { + h.App.MjTaskClients.Put(data.Key, wsClient) + } + resp.SUCCESS(c) +} diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 9c5225f8..245e9bd1 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -173,29 +173,8 @@ func (h *UserHandler) Login(c *gin.Context) { LoginIp: c.ClientIP(), LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), }) - var chatConfig types.ChatConfig - err = utils.JsonDecode(user.ChatConfig, &chatConfig) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - resp.SUCCESS(c, gin.H{ - "session_id": sessionId, - "id": user.Id, - "nickname": user.Nickname, - "avatar": user.Avatar, - "username": user.Username, - "tokens": user.Tokens, - "calls": user.Calls, - "expired_time": user.ExpiredTime, - "api_key": chatConfig.ApiKey, - "model": chatConfig.Model, - "temperature": chatConfig.Temperature, - "max_tokens": chatConfig.MaxTokens, - "enable_context": chatConfig.EnableContext, - "enable_history": chatConfig.EnableHistory, - }) + resp.SUCCESS(c, sessionId) } // Logout 注 销 diff --git a/api/main.go b/api/main.go index 64f61d2c..5281c65f 100644 --- a/api/main.go +++ b/api/main.go @@ -178,6 +178,7 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { s.Engine.POST("/api/mj/notify", h.Notify) + s.Engine.POST("/api/mj/upscale", h.Upscale) }), // 管理后台控制器 diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index 52b83053..47455db1 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -29,7 +29,7 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { return "", errors.New("无效的 API Token") } - //logger.Infof("MJ 绘画参数:%+v", params) + logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) if !utils.IsEmptyValue(params["ar"]) { prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) @@ -60,6 +60,31 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { return prompt, nil } +type MjUpscaleReq struct { + Index int32 `json:"index"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` +} + +func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error { + url := fmt.Sprintf("%s/api/mj/upscale", f.config.ApiURL) + var res types.BizVo + r, err := f.client.R(). + SetHeader("Authorization", f.config.Token). + SetHeader("Content-Type", "application/json"). + SetBody(upReq). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("%v%v", r.String(), err) + } + + if res.Code != types.Success { + return errors.New(res.Message) + } + + return nil +} + func (f FuncMidJourney) Name() string { return f.name } diff --git a/api/store/model/mj_job.go b/api/store/model/mj_job.go new file mode 100644 index 00000000..74f09563 --- /dev/null +++ b/api/store/model/mj_job.go @@ -0,0 +1,19 @@ +package model + +import "time" + +type MidJourneyJob struct { + Id uint `gorm:"primarykey;column:id"` + UserId uint + ChatId string + MessageId string + Hash string + Content string + Prompt string + Image string + CreatedAt time.Time +} + +func (MidJourneyJob) TableName() string { + return "chatgpt_mj_jobs" +} diff --git a/database/update.sql b/database/update-v3.0.6.sql similarity index 100% rename from database/update.sql rename to database/update-v3.0.6.sql diff --git a/database/update-v3.0.7.sql b/database/update-v3.0.7.sql new file mode 100644 index 00000000..5031672a --- /dev/null +++ b/database/update-v3.0.7.sql @@ -0,0 +1,34 @@ +CREATE TABLE `chatgpt_mj_jobs` ( + `id` int NOT NULL, + `user_id` int NOT NULL COMMENT '用户 ID', + `chat_id` char(40) NOT NULL COMMENT '聊天会话 ID', + `message_id` char(40) NOT NULL COMMENT '消息 ID', + `hash` char(40) NOT NULL COMMENT '图片哈希', + `content` varchar(2000) NOT NULL COMMENT '消息内容', + `prompt` varchar(2000) NOT NULL COMMENT '会话提示词', + `image` text NOT NULL COMMENT '图片信息 json', + `created_at` datetime NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='MidJourney 任务表'; + +-- +-- 转储表的索引 +-- + +-- +-- 表的索引 `chatgpt_mj_jobs` +-- +ALTER TABLE `chatgpt_mj_jobs` + ADD PRIMARY KEY (`id`), + ADD UNIQUE KEY `message_id` (`message_id`), + ADD UNIQUE KEY `hash` (`hash`); + +-- +-- 在导出的表使用AUTO_INCREMENT +-- + +-- +-- 使用表AUTO_INCREMENT `chatgpt_mj_jobs` +-- +ALTER TABLE `chatgpt_mj_jobs` + MODIFY `id` int NOT NULL AUTO_INCREMENT; +COMMIT; diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue index 2a963c51..593a1dd8 100644 --- a/web/src/components/ChatMidJourney.vue +++ b/web/src/components/ChatMidJourney.vue @@ -1,5 +1,5 @@