mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-12-26 01:55:58 +08:00
feat: support CDN reverse proxy for MidJourney and OpenAI API
This commit is contained in:
@@ -13,7 +13,9 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -58,6 +60,27 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
logger.Info("Invalid user ID")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.pool.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
@@ -147,6 +170,9 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
|
||||
// update user's img calls
|
||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
resp.SUCCESS(c)
|
||||
@@ -205,6 +231,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -226,6 +256,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
ChannelId: data.ChannelId,
|
||||
ReferenceId: data.MessageId,
|
||||
UserId: userId,
|
||||
TaskId: data.TaskId,
|
||||
@@ -250,6 +281,9 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
MessageHash: data.MessageHash,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
|
||||
// update user's img calls
|
||||
h.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
resp.SUCCESS(c)
|
||||
@@ -320,6 +354,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||
func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
ImgURL string `json:"img_url"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
@@ -340,5 +375,8 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(data.UserId)
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user