mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-24 04:06:39 +08:00
fix socket connect for mj task notify
This commit is contained in:
parent
f307b8ba7a
commit
59f316b341
@ -159,7 +159,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|||||||
var tokenString string
|
var tokenString string
|
||||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
|
||||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||||
} else if c.Request.URL.Path == "/api/chat/new" {
|
} else if c.Request.URL.Path == "/api/chat/new" || c.Request.URL.Path == "/api/mj/client" {
|
||||||
tokenString = c.Query("token")
|
tokenString = c.Query("token")
|
||||||
} else {
|
} else {
|
||||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||||
|
@ -86,11 +86,10 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
|||||||
|
|
||||||
sessionId := c.Query("session_id")
|
sessionId := c.Query("session_id")
|
||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
// 关闭旧的连接
|
// 删除旧的连接
|
||||||
if h.clients.Has(sessionId) {
|
h.clients.Delete(sessionId)
|
||||||
h.clients.Get(sessionId).Close()
|
|
||||||
}
|
|
||||||
h.clients.Put(sessionId, client)
|
h.clients.Put(sessionId, client)
|
||||||
|
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||||
@ -265,9 +264,30 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新用户剩余绘图次数
|
||||||
|
if data.Status == Finished && task.Type != service.Upscale {
|
||||||
|
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
|
}
|
||||||
|
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||||
|
user, err := utils.GetLoginUser(c, h.db)
|
||||||
|
if err != nil {
|
||||||
|
resp.NotAuth(c)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.ImgCalls <= 0 {
|
||||||
|
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Image 创建一个绘画任务
|
// Image 创建一个绘画任务
|
||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
@ -286,6 +306,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.checkLimits(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var prompt = data.Prompt
|
var prompt = data.Prompt
|
||||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||||
prompt += " --ar " + data.Rate
|
prompt += " --ar " + data.Rate
|
||||||
@ -367,6 +391,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.checkLimits(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
@ -432,6 +460,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.checkLimits(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
idValue, _ := c.Get(types.LoginUserID)
|
||||||
jobId := 0
|
jobId := 0
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
|
@ -195,6 +195,7 @@ func main() {
|
|||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
group.GET("jobs", h.JobList)
|
group.GET("jobs", h.JobList)
|
||||||
|
group.Any("client", h.Client)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
|
@ -322,6 +322,13 @@
|
|||||||
img {
|
img {
|
||||||
height 304px
|
height 304px
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.el-image-viewer__wrapper {
|
||||||
|
img {
|
||||||
|
width auto
|
||||||
|
height auto
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -285,7 +285,7 @@
|
|||||||
placement="top-start"
|
placement="top-start"
|
||||||
title="提示词"
|
title="提示词"
|
||||||
:width="240"
|
:width="240"
|
||||||
trigger="click"
|
trigger="hover"
|
||||||
>
|
>
|
||||||
<template #reference>
|
<template #reference>
|
||||||
<el-image :src="scope.item.img_url" :class="scope.item.type === 'upscale'?'upscale':''"
|
<el-image :src="scope.item.img_url" :class="scope.item.type === 'upscale'?'upscale':''"
|
||||||
@ -352,7 +352,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import {nextTick, onMounted, ref} from "vue"
|
import {onMounted, ref} from "vue"
|
||||||
import {DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue";
|
import {DeleteFilled, DocumentCopy, InfoFilled, Picture, Plus} from "@element-plus/icons-vue";
|
||||||
import Compressor from "compressorjs";
|
import Compressor from "compressorjs";
|
||||||
import {httpGet, httpPost} from "@/utils/http";
|
import {httpGet, httpPost} from "@/utils/http";
|
||||||
@ -362,8 +362,6 @@ import Clipboard from "clipboard";
|
|||||||
import {checkSession} from "@/action/session";
|
import {checkSession} from "@/action/session";
|
||||||
import {useRouter} from "vue-router";
|
import {useRouter} from "vue-router";
|
||||||
import {getSessionId, getUserToken} from "@/store/session";
|
import {getSessionId, getUserToken} from "@/store/session";
|
||||||
import {randString} from "@/utils/libs";
|
|
||||||
import hl from "highlight.js";
|
|
||||||
|
|
||||||
const listBoxHeight = ref(window.innerHeight - 40)
|
const listBoxHeight = ref(window.innerHeight - 40)
|
||||||
const mjBoxHeight = ref(window.innerHeight - 150)
|
const mjBoxHeight = ref(window.innerHeight - 150)
|
||||||
@ -409,7 +407,7 @@ const connect = () => {
|
|||||||
host = 'ws://' + location.host;
|
host = 'ws://' + location.host;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}`);
|
const _socket = new WebSocket(host + `/api/mj/client?session_id=${getSessionId()}&token=${getUserToken()}`);
|
||||||
_socket.addEventListener('open', () => {
|
_socket.addEventListener('open', () => {
|
||||||
socket.value = _socket;
|
socket.value = _socket;
|
||||||
});
|
});
|
||||||
@ -421,7 +419,7 @@ const connect = () => {
|
|||||||
reader.onload = () => {
|
reader.onload = () => {
|
||||||
const data = JSON.parse(String(reader.result));
|
const data = JSON.parse(String(reader.result));
|
||||||
console.log(data)
|
console.log(data)
|
||||||
let isNew = false
|
let isNew = true
|
||||||
if (data.progress === 100) {
|
if (data.progress === 100) {
|
||||||
for (let i = 0; i < finishedJobs.value.length; i++) {
|
for (let i = 0; i < finishedJobs.value.length; i++) {
|
||||||
if (finishedJobs.value[i].id === data.id) {
|
if (finishedJobs.value[i].id === data.id) {
|
||||||
@ -429,11 +427,17 @@ const connect = () => {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (let i = 0; i < runningJobs.value.length; i++) {
|
||||||
|
if (runningJobs.value[i].id === data.id) {
|
||||||
|
runningJobs.value.splice(i, 1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if (isNew) {
|
if (isNew) {
|
||||||
finishedJobs.value.unshift(data)
|
finishedJobs.value.unshift(data)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (let i = 0; i < runningJobs.value; i++) {
|
for (let i = 0; i < runningJobs.value.length; i++) {
|
||||||
if (runningJobs.value[i].id === data.id) {
|
if (runningJobs.value[i].id === data.id) {
|
||||||
isNew = false
|
isNew = false
|
||||||
runningJobs.value[i] = data
|
runningJobs.value[i] = data
|
||||||
@ -480,6 +484,9 @@ onMounted(() => {
|
|||||||
router.push('/login')
|
router.push('/login')
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// 连接 socket
|
||||||
|
connect();
|
||||||
|
|
||||||
const clipboard = new Clipboard('.copy-prompt');
|
const clipboard = new Clipboard('.copy-prompt');
|
||||||
clipboard.on('success', () => {
|
clipboard.on('success', () => {
|
||||||
ElMessage.success({message: "复制成功!", duration: 500});
|
ElMessage.success({message: "复制成功!", duration: 500});
|
||||||
|
Loading…
Reference in New Issue
Block a user