mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: midjourney image variation function is ready
This commit is contained in:
parent
113769791f
commit
e65c32c4c7
@ -83,15 +83,15 @@ var InnerFunctions = []Function{
|
|||||||
Properties: map[string]Property{
|
Properties: map[string]Property{
|
||||||
"prompt": {
|
"prompt": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "绘画内容描述,提示词,此参数需要翻译成英文",
|
Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
|
||||||
},
|
},
|
||||||
"ar": {
|
"ar": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "图片长宽比,如 16:9, --ar 3:2",
|
Description: "图片长宽比,如 --ar 4:3",
|
||||||
},
|
},
|
||||||
"niji": {
|
"niji": {
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "动漫模型版本,如 --niji 5",
|
Description: "动漫模型版本,例如 --niji 5",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Required: []string{},
|
Required: []string{},
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
@ -83,7 +84,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task)
|
err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with get MidJourney task: ", err)
|
logger.Error("error with get MidJourney task: ", err)
|
||||||
resp.ERROR(c, err.Error())
|
resp.SUCCESS(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,16 +139,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
resp.SUCCESS(c, "SUCCESS")
|
resp.SUCCESS(c, "SUCCESS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upscale send upscale command to MidJourney Bot
|
type reqVo struct {
|
||||||
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|
||||||
var data struct {
|
|
||||||
Index int32 `json:"index"`
|
Index int32 `json:"index"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
MessageHash string `json:"message_hash"`
|
MessageHash string `json:"message_hash"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale send upscale command to MidJourney Bot
|
||||||
|
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||||
|
var data reqVo
|
||||||
if err := c.ShouldBindJSON(&data); err != nil ||
|
if err := c.ShouldBindJSON(&data); err != nil ||
|
||||||
data.SessionId == "" ||
|
data.SessionId == "" ||
|
||||||
data.Key == "" {
|
data.Key == "" {
|
||||||
@ -170,7 +173,39 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(wsClient, "已推送放大图片任务到 MidJourney 机器人,请耐心等待任务执行...")
|
content := fmt.Sprintf("**%s** 已推送 Upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
|
if h.App.MjTaskClients.Get(data.Key) == nil {
|
||||||
|
h.App.MjTaskClients.Put(data.Key, wsClient)
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||||
|
var data reqVo
|
||||||
|
if err := c.ShouldBindJSON(&data); err != nil ||
|
||||||
|
data.SessionId == "" ||
|
||||||
|
data.Key == "" {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wsClient := h.App.ChatClients.Get(data.SessionId)
|
||||||
|
if wsClient == nil {
|
||||||
|
resp.ERROR(c, "No Websocket client online")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.mjFunc.Variation(function.MjVariationReq{
|
||||||
|
Index: data.Index,
|
||||||
|
MessageId: data.MessageId,
|
||||||
|
MessageHash: data.MessageHash,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
content := fmt.Sprintf("**%s** 已推送 Variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
if h.App.MjTaskClients.Get(data.Key) == nil {
|
if h.App.MjTaskClients.Get(data.Key) == nil {
|
||||||
h.App.MjTaskClients.Put(data.Key, wsClient)
|
h.App.MjTaskClients.Put(data.Key, wsClient)
|
||||||
}
|
}
|
||||||
|
@ -179,6 +179,7 @@ func main() {
|
|||||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||||
s.Engine.POST("/api/mj/notify", h.Notify)
|
s.Engine.POST("/api/mj/notify", h.Notify)
|
||||||
s.Engine.POST("/api/mj/upscale", h.Upscale)
|
s.Engine.POST("/api/mj/upscale", h.Upscale)
|
||||||
|
s.Engine.POST("/api/mj/variation", h.Variation)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// 管理后台控制器
|
// 管理后台控制器
|
||||||
|
@ -85,6 +85,31 @@ func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MjVariationReq struct {
|
||||||
|
Index int32 `json:"index"`
|
||||||
|
MessageId string `json:"message_id"`
|
||||||
|
MessageHash string `json:"message_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FuncMidJourney) Variation(upReq MjVariationReq) error {
|
||||||
|
url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL)
|
||||||
|
var res types.BizVo
|
||||||
|
r, err := f.client.R().
|
||||||
|
SetHeader("Authorization", f.config.Token).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(upReq).
|
||||||
|
SetSuccessResult(&res).Post(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return fmt.Errorf("%v%v", r.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
return errors.New(res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (f FuncMidJourney) Name() string {
|
func (f FuncMidJourney) Name() string {
|
||||||
return f.name
|
return f.name
|
||||||
}
|
}
|
||||||
|
@ -92,14 +92,23 @@ watch(() => props.content, (newVal) => {
|
|||||||
});
|
});
|
||||||
const emits = defineEmits(['disable-input', 'disable-input']);
|
const emits = defineEmits(['disable-input', 'disable-input']);
|
||||||
const upscale = (index) => {
|
const upscale = (index) => {
|
||||||
|
send('/api/mj/upscale', index)
|
||||||
|
}
|
||||||
|
|
||||||
|
const variation = (index) => {
|
||||||
|
send('/api/mj/variation', index)
|
||||||
|
}
|
||||||
|
|
||||||
|
const send = (url, index) => {
|
||||||
loading.value = true
|
loading.value = true
|
||||||
emits('disable-input')
|
emits('disable-input')
|
||||||
httpPost("/api/mj/upscale", {
|
httpPost(url, {
|
||||||
index: index,
|
index: index,
|
||||||
message_id: data.value?.["message_id"],
|
message_id: data.value?.["message_id"],
|
||||||
message_hash: data.value?.["image"]?.hash,
|
message_hash: data.value?.["image"]?.hash,
|
||||||
session_id: getSessionId(),
|
session_id: getSessionId(),
|
||||||
key: data.value?.["key"]
|
key: data.value?.["key"],
|
||||||
|
prompt: data.value?.["prompt"],
|
||||||
}).then(() => {
|
}).then(() => {
|
||||||
ElMessage.success("任务推送成功,请耐心等待任务执行...")
|
ElMessage.success("任务推送成功,请耐心等待任务执行...")
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@ -108,10 +117,6 @@ const upscale = (index) => {
|
|||||||
emits('disable-input')
|
emits('disable-input')
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const variation = (index) => {
|
|
||||||
ElMessage.warning("当前版本暂未实现 Variation 功能!")
|
|
||||||
}
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style lang="stylus">
|
<style lang="stylus">
|
||||||
|
@ -554,6 +554,12 @@ const connect = function (chat_id, role_id) {
|
|||||||
const content = data.content;
|
const content = data.content;
|
||||||
const md = require('markdown-it')({breaks: true});
|
const md = require('markdown-it')({breaks: true});
|
||||||
content.content = md.render(content.content)
|
content.content = md.render(content.content)
|
||||||
|
let key = content.key
|
||||||
|
// fixed bug: 执行 Upscale 和 Variation 操作的时候覆盖之前的绘画
|
||||||
|
if (content.status === "Finished") {
|
||||||
|
key = randString(32)
|
||||||
|
enableInput()
|
||||||
|
}
|
||||||
// console.log(content)
|
// console.log(content)
|
||||||
// check if the message is in chatData
|
// check if the message is in chatData
|
||||||
let flag = false
|
let flag = false
|
||||||
@ -562,21 +568,19 @@ const connect = function (chat_id, role_id) {
|
|||||||
console.log(chatData.value[i])
|
console.log(chatData.value[i])
|
||||||
flag = true
|
flag = true
|
||||||
chatData.value[i].content = content
|
chatData.value[i].content = content
|
||||||
|
chatData.value[i].id = key
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (flag === false) {
|
if (flag === false) {
|
||||||
chatData.value.push({
|
chatData.value.push({
|
||||||
type: "mj",
|
type: "mj",
|
||||||
id: content["message_id"],
|
id: key,
|
||||||
icon: "/images/avatar/mid_journey.png",
|
icon: "/images/avatar/mid_journey.png",
|
||||||
content: content
|
content: content
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (content.status === "Finished") {
|
|
||||||
enableInput()
|
|
||||||
}
|
|
||||||
} else if (data.type === 'end') { // 消息接收完毕
|
} else if (data.type === 'end') { // 消息接收完毕
|
||||||
// 追加当前会话到会话列表
|
// 追加当前会话到会话列表
|
||||||
if (isNewChat && newChatItem.value !== null) {
|
if (isNewChat && newChatItem.value !== null) {
|
||||||
|
@ -44,7 +44,6 @@ import {onMounted, ref} from "vue";
|
|||||||
import {Lock, UserFilled} from "@element-plus/icons-vue";
|
import {Lock, UserFilled} from "@element-plus/icons-vue";
|
||||||
import {httpPost} from "@/utils/http";
|
import {httpPost} from "@/utils/http";
|
||||||
import {ElMessage} from "element-plus";
|
import {ElMessage} from "element-plus";
|
||||||
import {setSession} from "@/store/session";
|
|
||||||
import {useRouter} from "vue-router";
|
import {useRouter} from "vue-router";
|
||||||
import FooterBar from "@/components/FooterBar.vue";
|
import FooterBar from "@/components/FooterBar.vue";
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user