mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-20 10:16:39 +08:00
fix: fix major bugs for unauthorized access to data
This commit is contained in:
parent
2c7d472069
commit
2ac44cdeb6
@ -150,11 +150,11 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|||||||
c.Request.URL.Path == "/api/chat/history" ||
|
c.Request.URL.Path == "/api/chat/history" ||
|
||||||
c.Request.URL.Path == "/api/chat/detail" ||
|
c.Request.URL.Path == "/api/chat/detail" ||
|
||||||
c.Request.URL.Path == "/api/role/list" ||
|
c.Request.URL.Path == "/api/role/list" ||
|
||||||
c.Request.URL.Path == "/api/mj/jobs" ||
|
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||||
c.Request.URL.Path == "/api/mj/client" ||
|
c.Request.URL.Path == "/api/mj/client" ||
|
||||||
c.Request.URL.Path == "/api/mj/notify" ||
|
c.Request.URL.Path == "/api/mj/notify" ||
|
||||||
c.Request.URL.Path == "/api/invite/hits" ||
|
c.Request.URL.Path == "/api/invite/hits" ||
|
||||||
c.Request.URL.Path == "/api/sd/jobs" ||
|
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||||
c.Request.URL.Path == "/api/sd/client" ||
|
c.Request.URL.Path == "/api/sd/client" ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -17,6 +18,13 @@ func (h *ChatHandler) List(c *gin.Context) {
|
|||||||
resp.ERROR(c, "The parameter 'user_id' is needed.")
|
resp.ERROR(c, "The parameter 'user_id' is needed.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fix: 只能读取本人的消息列表
|
||||||
|
if uint(userId) != h.GetLoginUserId(c) {
|
||||||
|
resp.ERROR(c, "Hacker attempt, you can ONLY get yourself chats.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var items = make([]vo.ChatItem, 0)
|
var items = make([]vo.ChatItem, 0)
|
||||||
var chats []model.ChatItem
|
var chats []model.ChatItem
|
||||||
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
|
||||||
@ -116,9 +124,10 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
|||||||
// History 获取聊天历史记录
|
// History 获取聊天历史记录
|
||||||
func (h *ChatHandler) History(c *gin.Context) {
|
func (h *ChatHandler) History(c *gin.Context) {
|
||||||
chatId := c.Query("chat_id") // 会话 ID
|
chatId := c.Query("chat_id") // 会话 ID
|
||||||
|
userId := h.GetLoginUserId(c)
|
||||||
var items []model.ChatMessage
|
var items []model.ChatMessage
|
||||||
var messages = make([]vo.HistoryMessage, 0)
|
var messages = make([]vo.HistoryMessage, 0)
|
||||||
res := h.db.Where("chat_id = ?", chatId).Find(&items)
|
res := h.db.Where("user_id = ? AND chat_id = ?", userId, chatId).Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, "No history message")
|
resp.ERROR(c, "No history message")
|
||||||
return
|
return
|
||||||
|
@ -305,16 +305,40 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImgWall 照片墙
|
||||||
|
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
// JobList 获取 MJ 任务列表
|
// JobList 获取 MJ 任务列表
|
||||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
status := h.GetBool(c, "status")
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
publish := h.GetBool(c, "publish")
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.db.Session(&gorm.Session{})
|
||||||
if status == 1 {
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@ -333,8 +357,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.MidJourneyJob
|
var items []model.MidJourneyJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.MidJourneyJob, 0)
|
var jobs = make([]vo.MidJourneyJob, 0)
|
||||||
@ -366,7 +389,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
|||||||
|
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -31,8 +32,8 @@ func (h *OrderHandler) List(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, _ := utils.GetLoginUser(c, h.db)
|
userId := h.GetLoginUserId(c)
|
||||||
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
|
session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
|
||||||
var total int64
|
var total int64
|
||||||
session.Model(&model.Order{}).Count(&total)
|
session.Model(&model.Order{}).Count(&total)
|
||||||
var items []model.Order
|
var items []model.Order
|
||||||
|
@ -11,10 +11,11 @@ import (
|
|||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -167,16 +168,41 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
|||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobList 获取 stable diffusion 任务列表
|
// ImgWall 照片墙
|
||||||
|
func (h *SdJobHandler) ImgWall(c *gin.Context) {
|
||||||
|
page := h.GetInt(c, "page", 0)
|
||||||
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
|
err, jobs := h.getData(true, 0, page, pageSize, true)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 SD 任务列表
|
||||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||||
status := h.GetInt(c, "status", 0)
|
status := h.GetBool(c, "status")
|
||||||
userId := h.GetInt(c, "user_id", 0)
|
userId := h.GetLoginUserId(c)
|
||||||
page := h.GetInt(c, "page", 0)
|
page := h.GetInt(c, "page", 0)
|
||||||
pageSize := h.GetInt(c, "page_size", 0)
|
pageSize := h.GetInt(c, "page_size", 0)
|
||||||
publish := h.GetBool(c, "publish")
|
publish := h.GetBool(c, "publish")
|
||||||
|
|
||||||
|
err, jobs := h.getData(status, userId, page, pageSize, publish)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SUCCESS(c, jobs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JobList 获取 MJ 任务列表
|
||||||
|
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
|
||||||
|
|
||||||
session := h.db.Session(&gorm.Session{})
|
session := h.db.Session(&gorm.Session{})
|
||||||
if status == 1 {
|
if finish {
|
||||||
session = session.Where("progress = ?", 100).Order("id DESC")
|
session = session.Where("progress = ?", 100).Order("id DESC")
|
||||||
} else {
|
} else {
|
||||||
session = session.Where("progress < ?", 100).Order("id ASC")
|
session = session.Where("progress < ?", 100).Order("id ASC")
|
||||||
@ -195,8 +221,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
var items []model.SdJob
|
var items []model.SdJob
|
||||||
res := session.Find(&items)
|
res := session.Find(&items)
|
||||||
if res.Error != nil {
|
if res.Error != nil {
|
||||||
resp.ERROR(c, types.NoData)
|
return res.Error, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var jobs = make([]vo.SdJob, 0)
|
var jobs = make([]vo.SdJob, 0)
|
||||||
@ -227,7 +252,8 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
jobs = append(jobs, job)
|
jobs = append(jobs, job)
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, jobs)
|
|
||||||
|
return nil, jobs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove remove task image
|
// Remove remove task image
|
||||||
|
@ -241,6 +241,7 @@ func main() {
|
|||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
group.POST("notify", h.Notify)
|
group.POST("notify", h.Notify)
|
||||||
group.POST("publish", h.Publish)
|
group.POST("publish", h.Publish)
|
||||||
@ -250,6 +251,7 @@ func main() {
|
|||||||
group.Any("client", h.Client)
|
group.Any("client", h.Client)
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.GET("imgWall", h.ImgWall)
|
||||||
group.POST("remove", h.Remove)
|
group.POST("remove", h.Remove)
|
||||||
group.POST("publish", h.Publish)
|
group.POST("publish", h.Publish)
|
||||||
}),
|
}),
|
||||||
|
@ -7,9 +7,10 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,15 +91,16 @@ func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
|
|||||||
SetErrorResult(&errRes).
|
SetErrorResult(&errRes).
|
||||||
Post(apiURL)
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if r != nil {
|
if r.Body != nil {
|
||||||
errStr, _ := io.ReadAll(r.Body)
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
logger.Errorf("API URL: %s, 返回:%s", string(errStr), apiURL)
|
logger.Errorf("API 返回:%s, API URL: %s", string(errStr), apiURL)
|
||||||
}
|
}
|
||||||
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.IsErrorState() {
|
if r.IsErrorState() {
|
||||||
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
|
errStr, _ := io.ReadAll(r.Body)
|
||||||
|
return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
@ -124,6 +126,10 @@ func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(body.Base64Array) < 2 {
|
||||||
|
return ImageRes{}, errors.New("blend must use more than 2 images")
|
||||||
|
}
|
||||||
var res ImageRes
|
var res ImageRes
|
||||||
var errRes ErrRes
|
var errRes ErrRes
|
||||||
r, err := req.C().R().
|
r, err := req.C().R().
|
||||||
@ -149,19 +155,19 @@ func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
|
|||||||
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
|
||||||
// 生成图片 Base64 编码
|
// 生成图片 Base64 编码
|
||||||
if len(task.ImgArr) != 2 {
|
if len(task.ImgArr) != 2 {
|
||||||
return ImageRes{}, errors.New("参数错误,必须上传2张图片")
|
return ImageRes{}, errors.New("invalid params, swap face must pass 2 images")
|
||||||
}
|
}
|
||||||
var sourceBase64 string
|
var sourceBase64 string
|
||||||
var targetBase64 string
|
var targetBase64 string
|
||||||
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with download image: ", err)
|
return ImageRes{}, fmt.Errorf("error with download source image: %v", err)
|
||||||
} else {
|
} else {
|
||||||
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
}
|
}
|
||||||
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with download image: ", err)
|
return ImageRes{}, fmt.Errorf("error with download target image: %v", err)
|
||||||
} else {
|
} else {
|
||||||
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
}
|
}
|
||||||
|
@ -284,8 +284,8 @@ const getNext = () => {
|
|||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
page.value = page.value + 1
|
page.value = page.value + 1
|
||||||
const url = imgType.value === "mj" ? "/api/mj/jobs" : "/api/sd/jobs"
|
const url = imgType.value === "mj" ? "/api/mj/imgWall" : "/api/sd/imgWall"
|
||||||
httpGet(`${url}?status=1&page=${page.value}&page_size=${pageSize.value}&publish=true`).then(res => {
|
httpGet(`${url}?page=${page.value}&page_size=${pageSize.value}`).then(res => {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
if (res.data.length === 0) {
|
if (res.data.length === 0) {
|
||||||
isOver.value = true
|
isOver.value = true
|
||||||
|
Loading…
Reference in New Issue
Block a user