From 3f856afec88e30bbaa98b1a5d9cd8ea0e7916928 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sun, 3 Mar 2024 10:40:32 +0800 Subject: [PATCH 1/2] fix: fix major bugs for unauthorized access to data --- api/core/app_server.go | 4 +-- api/handler/chatimpl/chat_item_handler.go | 11 +++++- api/handler/mj_handler.go | 35 +++++++++++++++---- api/handler/order_handler.go | 5 +-- api/handler/sd_handler.go | 42 ++++++++++++++++++----- api/main.go | 2 ++ api/service/mj/plus/client.go | 20 +++++++---- web/src/views/ImagesWall.vue | 4 +-- 8 files changed, 95 insertions(+), 28 deletions(-) diff --git a/api/core/app_server.go b/api/core/app_server.go index 95a3b22b..c166c0dc 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -150,11 +150,11 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/role/list" || - c.Request.URL.Path == "/api/mj/jobs" || + c.Request.URL.Path == "/api/mj/imgWall" || c.Request.URL.Path == "/api/mj/client" || c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/invite/hits" || - c.Request.URL.Path == "/api/sd/jobs" || + c.Request.URL.Path == "/api/sd/imgWall" || c.Request.URL.Path == "/api/sd/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/function/") || diff --git a/api/handler/chatimpl/chat_item_handler.go b/api/handler/chatimpl/chat_item_handler.go index 68996785..fee4a4fa 100644 --- a/api/handler/chatimpl/chat_item_handler.go +++ b/api/handler/chatimpl/chat_item_handler.go @@ -6,6 +6,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -17,6 +18,13 @@ func (h *ChatHandler) List(c *gin.Context) { resp.ERROR(c, "The parameter 'user_id' is needed.") return } + + // fix: 只能读取本人的消息列表 + if uint(userId) != h.GetLoginUserId(c) { + resp.ERROR(c, "Hacker attempt, you can ONLY get yourself chats.") + return + } + var items = make([]vo.ChatItem, 0) var chats []model.ChatItem res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats) @@ -116,9 +124,10 @@ func (h *ChatHandler) Clear(c *gin.Context) { // History 获取聊天历史记录 func (h *ChatHandler) History(c *gin.Context) { chatId := c.Query("chat_id") // 会话 ID + userId := h.GetLoginUserId(c) var items []model.ChatMessage var messages = make([]vo.HistoryMessage, 0) - res := h.db.Where("chat_id = ?", chatId).Find(&items) + res := h.db.Where("user_id = ? AND chat_id = ?", userId, chatId).Find(&items) if res.Error != nil { resp.ERROR(c, "No history message") return diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 2a41f390..74a1cb0a 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -305,16 +305,40 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { resp.SUCCESS(c) } +// ImgWall 照片墙 +func (h *MidJourneyHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + // JobList 获取 MJ 任务列表 func (h *MidJourneyHandler) JobList(c *gin.Context) { - status := h.GetInt(c, "status", 0) - userId := h.GetInt(c, "user_id", 0) + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) page := h.GetInt(c, "page", 0) pageSize := h.GetInt(c, "page_size", 0) publish := h.GetBool(c, "publish") + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { session := h.db.Session(&gorm.Session{}) - if status == 1 { + if finish { session = session.Where("progress = ?", 100).Order("id DESC") } else { session = session.Where("progress < ?", 100).Order("id ASC") @@ -333,8 +357,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { var items []model.MidJourneyJob res := session.Find(&items) if res.Error != nil { - resp.ERROR(c, types.NoData) - return + return res.Error, nil } var jobs = make([]vo.MidJourneyJob, 0) @@ -366,7 +389,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) { jobs = append(jobs, job) } - resp.SUCCESS(c, jobs) + return nil, jobs } // Remove remove task image diff --git a/api/handler/order_handler.go b/api/handler/order_handler.go index 35d9e7ee..50327a23 100644 --- a/api/handler/order_handler.go +++ b/api/handler/order_handler.go @@ -7,6 +7,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -31,8 +32,8 @@ func (h *OrderHandler) List(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } - user, _ := utils.GetLoginUser(c, h.db) - session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess) + userId := h.GetLoginUserId(c) + session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) var total int64 session.Model(&model.Order{}).Count(&total) var items []model.Order diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index aec34b01..de3665fa 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -11,10 +11,11 @@ import ( "chatplus/utils/resp" "encoding/base64" "fmt" - "github.com/gorilla/websocket" "net/http" "time" + "github.com/gorilla/websocket" + "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "gorm.io/gorm" @@ -167,16 +168,41 @@ func (h *SdJobHandler) Image(c *gin.Context) { resp.SUCCESS(c) } -// JobList 获取 stable diffusion 任务列表 +// ImgWall 照片墙 +func (h *SdJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 func (h *SdJobHandler) JobList(c *gin.Context) { - status := h.GetInt(c, "status", 0) - userId := h.GetInt(c, "user_id", 0) + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) page := h.GetInt(c, "page", 0) pageSize := h.GetInt(c, "page_size", 0) publish := h.GetBool(c, "publish") + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { + session := h.db.Session(&gorm.Session{}) - if status == 1 { + if finish { session = session.Where("progress = ?", 100).Order("id DESC") } else { session = session.Where("progress < ?", 100).Order("id ASC") @@ -195,8 +221,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) { var items []model.SdJob res := session.Find(&items) if res.Error != nil { - resp.ERROR(c, types.NoData) - return + return res.Error, nil } var jobs = make([]vo.SdJob, 0) @@ -227,7 +252,8 @@ func (h *SdJobHandler) JobList(c *gin.Context) { } jobs = append(jobs, job) } - resp.SUCCESS(c, jobs) + + return nil, jobs } // Remove remove task image diff --git a/api/main.go b/api/main.go index cd1e9bda..8073a64c 100644 --- a/api/main.go +++ b/api/main.go @@ -241,6 +241,7 @@ func main() { group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) group.POST("remove", h.Remove) group.POST("notify", h.Notify) group.POST("publish", h.Publish) @@ -250,6 +251,7 @@ func main() { group.Any("client", h.Client) group.POST("image", h.Image) group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) group.POST("remove", h.Remove) group.POST("publish", h.Publish) }), diff --git a/api/service/mj/plus/client.go b/api/service/mj/plus/client.go index ebec3186..aa79672e 100644 --- a/api/service/mj/plus/client.go +++ b/api/service/mj/plus/client.go @@ -7,9 +7,10 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/imroc/req/v3" "io" + "github.com/imroc/req/v3" + "github.com/gin-gonic/gin" ) @@ -90,15 +91,16 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) { SetErrorResult(&errRes). Post(apiURL) if err != nil { - if r != nil { + if r.Body != nil { errStr, _ := io.ReadAll(r.Body) - logger.Errorf("API URL: %s, 返回:%s", string(errStr), apiURL) + logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL) } return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) } if r.IsErrorState() { - return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) } return res, nil @@ -124,6 +126,10 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) { } } } + + if len(body.Base64Array) < 2 { + return ImageRes{}, errors.New("blend must use more than 2 images") + } var res ImageRes var errRes ErrRes r, err := req.C().R(). @@ -149,19 +155,19 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) { apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) // 生成图片 Base64 编码 if len(task.ImgArr) != 2 { - return ImageRes{}, errors.New("参数错误,必须上传2张图片") + return ImageRes{}, errors.New("invalid params, swap face must pass 2 images") } var sourceBase64 string var targetBase64 string imageData, err := utils.DownloadImage(task.ImgArr[0], "") if err != nil { - logger.Error("error with download image: ", err) + return ImageRes{}, fmt.Errorf("error with download source image: %v", err) } else { sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) } imageData, err = utils.DownloadImage(task.ImgArr[1], "") if err != nil { - logger.Error("error with download image: ", err) + return ImageRes{}, fmt.Errorf("error with download target image: %v", err) } else { targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) } diff --git a/web/src/views/ImagesWall.vue b/web/src/views/ImagesWall.vue index cbb62a75..9d56236e 100644 --- a/web/src/views/ImagesWall.vue +++ b/web/src/views/ImagesWall.vue @@ -284,8 +284,8 @@ const getNext = () => { loading.value = true page.value = page.value + 1 - const url = imgType.value === "mj" ? "/api/mj/jobs" : "/api/sd/jobs" - httpGet(`${url}?status=1&page=${page.value}&page_size=${pageSize.value}&publish=true`).then(res => { + const url = imgType.value === "mj" ? "/api/mj/imgWall" : "/api/sd/imgWall" + httpGet(`${url}?page=${page.value}&page_size=${pageSize.value}`).then(res => { loading.value = false if (res.data.length === 0) { isOver.value = true From 33de83f2ac4efa25e4fef140f9593c3730d6f6d5 Mon Sep 17 00:00:00 2001 From: RockYang Date: Sun, 3 Mar 2024 19:27:22 +0800 Subject: [PATCH 2/2] feat: add removing order button in admin order list page --- api/config.sample.toml | 2 +- api/handler/admin/order_handler.go | 9 ++++++--- web/src/views/ChatPlus.vue | 14 +++----------- web/src/views/admin/Order.vue | 15 ++++++++++++++- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/api/config.sample.toml b/api/config.sample.toml index e3c9fa17..b89ca439 100644 --- a/api/config.sample.toml +++ b/api/config.sample.toml @@ -46,7 +46,7 @@ WeChatBot = false Active = "local" # 默认使用本地文件存储引擎 [OSS.Local] BasePath = "./static/upload" # 本地文件上传根路径 - BaseURL = "/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可 + BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP [OSS.Minio] Endpoint = "" # 如 172.22.11.200:9000 AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key diff --git a/api/handler/admin/order_handler.go b/api/handler/admin/order_handler.go index 229dd2a2..44edc839 100644 --- a/api/handler/admin/order_handler.go +++ b/api/handler/admin/order_handler.go @@ -8,6 +8,7 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -26,6 +27,7 @@ func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { func (h *OrderHandler) List(c *gin.Context) { var data struct { OrderNo string `json:"order_no"` + Status int `json:"status"` PayTime []string `json:"pay_time"` Page int `json:"page"` PageSize int `json:"page_size"` @@ -44,8 +46,9 @@ func (h *OrderHandler) List(c *gin.Context) { end := utils.Str2stamp(data.PayTime[1] + " 00:00:00") session = session.Where("pay_time >= ? AND pay_time <= ?", start, end) } - session = session.Where("status = ?", types.OrderPaidSuccess) - + if data.Status >= 0 { + session = session.Where("status", data.Status) + } var total int64 session.Model(&model.Order{}).Count(&total) var items []model.Order @@ -85,7 +88,7 @@ func (h *OrderHandler) Remove(c *gin.Context) { return } - res = h.db.Where("id = ?", id).Delete(&model.Order{}) + res = h.db.Unscoped().Where("id = ?", id).Delete(&model.Order{}) if res.Error != nil { resp.ERROR(c, "更新数据库失败!") return diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 5635eac6..4cf35e15 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -249,7 +249,7 @@ import { ArrowDown, Check, Close, - Delete, Document, + Delete, Edit, Plus, Promotion, @@ -259,15 +259,7 @@ import { VideoPause } from '@element-plus/icons-vue' import 'highlight.js/styles/a11y-dark.css' -import { - dateFormat, - escapeHTML, - isMobile, - processContent, - randString, - removeArrayItem, - UUID -} from "@/utils/libs"; +import {dateFormat, escapeHTML, isMobile, processContent, randString, removeArrayItem, UUID} from "@/utils/libs"; import {ElMessage, ElMessageBox} from "element-plus"; import hl from "highlight.js"; import {getSessionId, getUserToken, removeUserToken} from "@/store/session"; @@ -361,7 +353,7 @@ onMounted(() => { notice.value = md.render(res.data['content']) const oldNotice = localStorage.getItem(noticeKey.value); // 如果公告有更新,则显示公告 - if (oldNotice !== notice.value) { + if (oldNotice !== notice.value && notice.value.length > 10) { showNotice.value = true } }).catch(e => { diff --git a/web/src/views/admin/Order.vue b/web/src/views/admin/Order.vue index 2b79ccd9..ec966186 100644 --- a/web/src/views/admin/Order.vue +++ b/web/src/views/admin/Order.vue @@ -2,6 +2,14 @@
+ + + { fetchData()