diff --git a/api/core/app_server.go b/api/core/app_server.go index 95a3b22b..c166c0dc 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -150,11 +150,11 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/role/list" || - c.Request.URL.Path == "/api/mj/jobs" || + c.Request.URL.Path == "/api/mj/imgWall" || c.Request.URL.Path == "/api/mj/client" || c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/invite/hits" || - c.Request.URL.Path == "/api/sd/jobs" || + c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") || diff --git a/api/handler/chatimpl/chat_item_handler.go b/api/handler/chatimpl/chat_item_handler.go index 68996785..fee4a4fa 100644 --- a/api/handler/chatimpl/chat_item_handler.go +++ b/api/handler/chatimpl/chat_item_handler.go @@ -6,6 +6,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -17,6 +18,13 @@ func (h *ChatHandler) List(c *gin.Context) { resp.ERROR(c, "The parameter 'user_id' is needed.") return } + + // fix: 只能读取本人的消息列表 + if uint(userId) != h.GetLoginUserId(c) { + resp.ERROR(c, "Hacker attempt, you can ONLY get yourself chats.") + return + } + var items = make([]vo.ChatItem, 0) var chats []model.ChatItem res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats) @@ -116,9 +124,10 @@ func (h *ChatHandler) Clear(c *gin.Context) { // History 获取聊天历史记录 func (h *ChatHandler) History(c *gin.Context) { chatId := c.Query("chat_id") // 会话 ID + userId := h.GetLoginUserId(c) var items []model.ChatMessage var messages = make([]vo.HistoryMessage, 0) - res := h.db.Where("chat_id = ?", chatId).Find(&items) + res := h.db.Where("user_id = ? AND chat_id = ?", userId, chatId).Find(&items) if res.Error != nil { resp.ERROR(c, "No history message") return diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 2a41f390..74a1cb0a 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -305,16 +305,40 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { resp.SUCCESS(c) } +// ImgWall 照片墙 +func (h *MidJourneyHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + // JobList 获取 MJ 任务列表 func (h *MidJourneyHandler) JobList(c *gin.Context) { - status := h.GetInt(c, "status", 0) - userId := h.GetInt(c, "user_id", 0) + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) page := h.GetInt(c, "page", 0) pageSize := h.GetInt(c, "page_size", 0) publish := h.GetBool(c, "publish") + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { session := h.db.Session(&gorm.Session{}) - if status == 1 { + if finish { session = session.Where("progress = ?", 100).Order("id DESC") } else { session = session.Where("progress < ?", 100).Order("id ASC") @@ -333,8 +357,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { var items []model.MidJourneyJob res := session.Find(&items) if res.Error != nil { - resp.ERROR(c, types.NoData) - return + return res.Error, nil } var jobs = make([]vo.MidJourneyJob, 0) @@ -366,7 +389,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { jobs = append(jobs, job) } - resp.SUCCESS(c, jobs) + return nil, jobs } // Remove remove task image diff --git a/api/handler/order_handler.go b/api/handler/order_handler.go index 35d9e7ee..50327a23 100644 --- a/api/handler/order_handler.go +++ b/api/handler/order_handler.go @@ -7,6 +7,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -31,8 +32,8 @@ func (h *OrderHandler) List(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - user, _ := utils.GetLoginUser(c, h.db) - session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess) + userId := h.GetLoginUserId(c) + session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) var total int64 session.Model(&model.Order{}).Count(&total) var items []model.Order diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index aec34b01..de3665fa 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -11,10 +11,11 @@ import ( "chatplus/utils/resp" "encoding/base64" "fmt" - "github.com/gorilla/websocket" "net/http" "time" + "github.com/gorilla/websocket" + "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "gorm.io/gorm" @@ -167,16 +168,41 @@ func (h *SdJobHandler) Image(c *gin.Context) { resp.SUCCESS(c) } -// JobList 获取 stable diffusion 任务列表 +// ImgWall 照片墙 +func (h *SdJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 func (h *SdJobHandler) JobList(c *gin.Context) { - status := h.GetInt(c, "status", 0) - userId := h.GetInt(c, "user_id", 0) + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) page := h.GetInt(c, "page", 0) pageSize := h.GetInt(c, "page_size", 0) publish := h.GetBool(c, "publish") + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { + session := h.db.Session(&gorm.Session{}) - if status == 1 { + if finish { session = session.Where("progress = ?", 100).Order("id DESC") } else { session = session.Where("progress < ?", 100).Order("id ASC") @@ -195,8 +221,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) { var items []model.SdJob res := session.Find(&items) if res.Error != nil { - resp.ERROR(c, types.NoData) - return + return res.Error, nil } var jobs = make([]vo.SdJob, 0) @@ -227,7 +252,8 @@ func (h *SdJobHandler) JobList(c *gin.Context) { } jobs = append(jobs, job) } - resp.SUCCESS(c, jobs) + + return nil, jobs } // Remove remove task image diff --git a/api/main.go b/api/main.go index cd1e9bda..8073a64c 100644 --- a/api/main.go +++ b/api/main.go @@ -241,6 +241,7 @@ func main() { group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) group.POST("remove", h.Remove) group.POST("notify", h.Notify) group.POST("publish", h.Publish) @@ -250,6 +251,7 @@ func main() { group.Any("client", h.Client) group.POST("image", h.Image) group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) group.POST("remove", h.Remove) group.POST("publish", h.Publish) }), diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go index ebec3186..aa79672e 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus/client.go @@ -7,9 +7,10 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/imroc/req/v3" "io" + "github.com/imroc/req/v3" + "github.com/gin-gonic/gin" ) @@ -90,15 +91,16 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { SetErrorResult(&errRes). Post(apiURL) if err != nil { - if r != nil { + if r.Body != nil { errStr, _ := io.ReadAll(r.Body) - logger.Errorf("API URL: %s, 返回:%s", string(errStr), apiURL) + logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL) } return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) } if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) } return res, nil @@ -124,6 +126,10 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) { } } } + + if len(body.Base64Array) < 2 { + return ImageRes{}, errors.New("blend must use more than 2 images") + } var res ImageRes var errRes ErrRes r, err := req.C().R(). @@ -149,19 +155,19 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) // 生成图片 Base64 编码 if len(task.ImgArr) != 2 { - return ImageRes{}, errors.New("参数错误,必须上传2张图片") + return ImageRes{}, errors.New("invalid params, swap face must pass 2 images") } var sourceBase64 string var targetBase64 string imageData, err := utils.DownloadImage(task.ImgArr[0], "") if err != nil { - logger.Error("error with download image: ", err) + return ImageRes{}, fmt.Errorf("error with download source image: %v", err) } else { sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) } imageData, err = utils.DownloadImage(task.ImgArr[1], "") if err != nil { - logger.Error("error with download image: ", err) + return ImageRes{}, fmt.Errorf("error with download target image: %v", err) } else { targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) } diff --git a/web/src/views/ImagesWall.vue b/web/src/views/ImagesWall.vue index cbb62a75..9d56236e 100644 --- a/web/src/views/ImagesWall.vue +++ b/web/src/views/ImagesWall.vue @@ -284,8 +284,8 @@ const getNext = () => { loading.value = true page.value = page.value + 1 - const url = imgType.value === "mj" ? "/api/mj/jobs" : "/api/sd/jobs" - httpGet(`${url}?status=1&page=${page.value}&page_size=${pageSize.value}&publish=true`).then(res => { + const url = imgType.value === "mj" ? "/api/mj/imgWall" : "/api/sd/imgWall" + httpGet(`${url}?page=${page.value}&page_size=${pageSize.value}`).then(res => { loading.value = false if (res.data.length === 0) { isOver.value = true