feat: midjourney drawing image function is ready

This commit is contained in:
RockYang 2023-08-14 17:59:21 +08:00
parent 6e40f92aaf
commit 5d2a1d21d5
13 changed files with 357 additions and 30 deletions

View File

@ -43,7 +43,6 @@ type ChatSession struct {
} }
type MjTask struct { type MjTask struct {
Client Client
ChatId string ChatId string
MessageId string MessageId string
MessageHash string MessageHash string
@ -63,6 +62,7 @@ type ApiError struct {
const PromptMsg = "prompt" // prompt message const PromptMsg = "prompt" // prompt message
const ReplyMsg = "reply" // reply message const ReplyMsg = "reply" // reply message
const MjMsg = "mj"
var ModelToTokens = map[string]int{ var ModelToTokens = map[string]int{
"gpt-3.5-turbo": 4096, "gpt-3.5-turbo": 4096,
@ -70,3 +70,5 @@ var ModelToTokens = map[string]int{
"gpt-4": 8192, "gpt-4": 8192,
"gpt-4-32k": 32768, "gpt-4-32k": 32768,
} }
const TaskStorePrefix = "/tasks/"

View File

@ -87,7 +87,11 @@ var InnerFunctions = []Function{
}, },
"ar": { "ar": {
Type: "string", Type: "string",
Description: "图片长宽比,如 16:9", Description: "图片长宽比,如 16:9, --ar 3:2",
},
"niji": {
Type: "string",
Description: "动漫模型版本,如 --niji 5",
}, },
}, },
Required: []string{}, Required: []string{},

View File

@ -21,7 +21,7 @@ const (
WsStart = WsMsgType("start") WsStart = WsMsgType("start")
WsMiddle = WsMsgType("middle") WsMiddle = WsMsgType("middle")
WsEnd = WsMsgType("end") WsEnd = WsMsgType("end")
WsImg = WsMsgType("img") WsMjImg = WsMsgType("mj")
) )
type BizCode int type BizCode int

View File

@ -27,7 +27,6 @@ import (
) )
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。" const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
const TaskStorePrefix = "/tasks/"
type ChatHandler struct { type ChatHandler struct {
BaseHandler BaseHandler
@ -342,16 +341,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
content := data content := data
if functionName == types.FuncMidJourney { if functionName == types.FuncMidJourney {
key := utils.Sha256(data) key := utils.Sha256(data)
//logger.Info(data, ",", key)
// add task for MidJourney // add task for MidJourney
h.App.MjTaskClients.Put(key, ws) h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{ task := types.MjTask{
UserId: userVo.Id, UserId: userVo.Id,
RoleId: role.Id, RoleId: role.Id,
Icon: role.Icon, Icon: "/images/avatar/mid_journey.png",
Client: ws,
ChatId: session.ChatId, ChatId: session.ChatId,
} }
err := h.leveldb.Put(TaskStorePrefix+key, task) err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil { if err != nil {
logger.Error("error with store MidJourney task: ", err) logger.Error("error with store MidJourney task: ", err)
} }

View File

@ -3,9 +3,12 @@ package handler
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
type TaskStatus string type TaskStatus string
@ -29,10 +32,12 @@ type Image struct {
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
leveldb *store.LevelDB
db *gorm.DB
} }
func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, leveldb *store.LevelDB, db *gorm.DB) *MidJourneyHandler {
h := MidJourneyHandler{} h := MidJourneyHandler{leveldb: leveldb, db: db}
h.App = app h.App = app
return &h return &h
} }
@ -57,18 +62,54 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
} }
key := utils.Sha256(data.Prompt) key := utils.Sha256(data.Prompt)
data.Key = key data.Key = key
// TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录 //logger.Info(data.Prompt, ",", key)
if data.Status == Finished {
var task types.MjTask
err := h.leveldb.Get(types.TaskStorePrefix+key, &task)
if err != nil {
logger.Error("error with get MidJourney task: ", err)
resp.ERROR(c)
return
}
wsClient := h.App.MjTaskClients.Get(key) // TODO: 是否需要把图片下载到本地服务器?
if wsClient == nil { // 客户端断线,则丢弃
resp.SUCCESS(c) historyUserMsg := model.HistoryMessage{
return UserId: task.UserId,
ChatId: task.ChatId,
RoleId: task.RoleId,
Type: types.MjMsg,
Icon: task.Icon,
Content: utils.JsonEncode(data),
Tokens: 0,
UseContext: false,
}
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("error with save MidJourney message: ", res.Error)
}
// delete task from leveldb
_ = h.leveldb.Delete(types.TaskStorePrefix + key)
} }
// 推送消息到客户端 // 推送消息到客户端
// TODO: 增加绘画消息类型 wsClient := h.App.MjTaskClients.Get(key)
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data}) if wsClient == nil { // 客户端断线,则丢弃
resp.ERROR(c, "Error with CallBack") resp.SUCCESS(c, "Client is offline")
return
}
if data.Status == Finished {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
// delete client
h.App.MjTaskClients.Delete(key)
} else {
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
}
resp.SUCCESS(c, "SUCCESS")
} }

View File

@ -21,7 +21,7 @@ func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
return FuncMidJourney{ return FuncMidJourney{
name: "MidJourney AI 绘画", name: "MidJourney AI 绘画",
config: config, config: config,
client: req.C().SetTimeout(10 * time.Second)} client: req.C().SetTimeout(30 * time.Second)}
} }
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
@ -29,13 +29,19 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
return "", errors.New("无效的 API Token") return "", errors.New("无效的 API Token")
} }
logger.Infof("MJ 绘画参数:%+v", params) //logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"]) prompt := utils.InterfaceToString(params["prompt"])
if !utils.IsEmptyValue(params["ar"]) { if !utils.IsEmptyValue(params["ar"]) {
prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"]) prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"])
delete(params, "ar") delete(params, "--ar")
} }
prompt = prompt + " --niji 5" if !utils.IsEmptyValue(params["niji"]) {
prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"])
delete(params, "niji")
} else {
prompt = prompt + " --v 5.2"
}
params["prompt"] = prompt
url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL) url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL)
var res types.BizVo var res types.BizVo
r, err := f.client.R(). r, err := f.client.R().

View File

@ -89,6 +89,10 @@ func Ip2Region(searcher *xdb.Searcher, ip string) string {
} }
func IsEmptyValue(obj interface{}) bool { func IsEmptyValue(obj interface{}) bool {
if obj == nil {
return true
}
v := reflect.ValueOf(obj) v := reflect.ValueOf(obj)
switch v.Kind() { switch v.Kind() {
case reflect.Ptr, reflect.Interface: case reflect.Ptr, reflect.Interface:

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

View File

@ -0,0 +1,215 @@
<template>
<div class="chat-line chat-line-mj">
<div class="chat-line-inner">
<div class="chat-icon">
<img :src="icon" alt="User"/>
</div>
<div class="chat-item">
<div class="content">
<div class="text" v-html="data.content"></div>
<div class="images" v-if="data.image?.url !== ''">
<el-image :src="data.image?.url"
:zoom-rate="1.0"
:preview-src-list="[data.image?.url]"
:initial-index="0" lazy>
<template #placeholder>
<div class="image-slot"
:style="{height: height+'px', lineHeight:height+'px'}">
正在加载图片<span class="dot">...</span></div>
</template>
<template #error>
<div class="image-slot">
<el-icon>
<icon-picture/>
</el-icon>
</div>
</template>
</el-image>
</div>
</div>
<div class="opt" v-if="data.image?.hash !== ''">
<div class="opt-line">
<ul>
<li><a @click="upscale(1)">U1</a></li>
<li><a @click="upscale(2)">U2</a></li>
<li><a @click="upscale(3)">U3</a></li>
<li><a @click="upscale(4)">U4</a></li>
</ul>
</div>
<div class="opt-line">
<ul>
<li><a @click="variation(1)">V1</a></li>
<li><a @click="variation(2)">V2</a></li>
<li><a @click="variation(3)">V3</a></li>
<li><a @click="variation(4)">V4</a></li>
</ul>
</div>
</div>
<div class="bar" v-if="createdAt !== ''">
<span class="bar-item"><el-icon><Clock/></el-icon> {{ createdAt }}</span>
<span class="bar-item">tokens: {{ tokens }}</span>
</div>
</div>
</div>
</div>
</template>
<script setup>
import {ref, watch} from "vue";
import {Clock} from "@element-plus/icons-vue";
import {ElMessage} from "element-plus";
const props = defineProps({
content: Object,
icon: String,
createdAt: String
});
const data = ref(props.content)
console.log(data.value)
const tokens = ref(0)
const cacheKey = "img_placeholder_height"
const item = localStorage.getItem(cacheKey);
const height = ref(0)
if (item) {
height.value = parseInt(item)
}
if (data.value["image"]?.width > 0) {
height.value = 350 * data.value["image"]?.height / data.value["image"]?.width
localStorage.setItem(cacheKey, height.value)
}
watch(() => props.content, (newVal) => {
data.value = newVal;
});
const upscale = (index) => {
ElMessage.warning("当前版本暂未实现 Variation 功能!")
}
const variation = (index) => {
ElMessage.warning("当前版本暂未实现 Variation 功能!")
}
</script>
<style lang="stylus">
.chat-line-mj {
background-color #ffffff;
justify-content: center;
width 100%
padding-bottom: 1.5rem;
padding-top: 1.5rem;
border-bottom: 1px solid #d9d9e3;
.chat-line-inner {
display flex;
width 100%;
max-width 900px;
padding-left 10px;
.chat-icon {
margin-right 20px;
img {
width: 30px;
height: 30px;
border-radius: 10px;
padding: 1px;
}
}
.chat-item {
position: relative;
padding: 0 5px 0 0;
overflow: hidden;
.content {
word-break break-word;
padding: 6px 10px;
color #374151;
font-size: var(--content-font-size);
border-radius: 5px;
overflow: auto;
.text {
p:first-child {
margin-top 0
}
}
.images {
max-width 350px;
.el-image {
border-radius 10px;
.image-slot {
color #c1c1c1
width 350px
text-align center
border-radius 10px;
border 1px solid #e1e1e1
}
}
}
}
.opt {
.opt-line {
margin 6px 0
ul {
display flex
flex-flow row
padding-left 10px
li {
margin-right 10px
a {
padding 6px 0
width 64px
text-align center
border-radius 5px
display block
cursor pointer
background-color #4E5058
color #ffffff
&:hover {
background-color #6D6F78
}
}
}
}
}
}
.bar {
padding 10px;
.bar-item {
background-color #f7f7f8;
color #888
padding 3px 5px;
margin-right 10px;
border-radius 5px;
.el-icon {
position relative
top 2px;
}
}
}
}
}
}
</style>

View File

@ -64,7 +64,7 @@ export default defineComponent({
}) })
</script> </script>
<style lang="stylus" scoped> <style lang="stylus">
.chat-line-prompt { .chat-line-prompt {
background-color #ffffff; background-color #ffffff;
justify-content: center; justify-content: center;

View File

@ -22,6 +22,10 @@
:created-at="dateFormat(item['created_at'])" :created-at="dateFormat(item['created_at'])"
:tokens="item['tokens']" :tokens="item['tokens']"
:content="item.content"/> :content="item.content"/>
<chat-mid-journey v-else-if="item.type==='mj'"
:content="item.content"
:icon="item.icon"
:created-at="dateFormat(item['created_at'])"/>
</div> </div>
</div><!-- end chat box --> </div><!-- end chat box -->
</div> </div>
@ -38,6 +42,7 @@ import 'highlight.js/styles/a11y-dark.css'
import hl from "highlight.js"; import hl from "highlight.js";
import {ElMessage} from "element-plus"; import {ElMessage} from "element-plus";
import {Promotion} from "@element-plus/icons-vue"; import {Promotion} from "@element-plus/icons-vue";
import ChatMidJourney from "@/components/ChatMidJourney.vue";
const chatData = ref([]) const chatData = ref([])
const router = useRouter() const router = useRouter()
@ -57,6 +62,11 @@ httpGet('/api/chat/history?chat_id=' + chatId).then(res => {
if (data[i].type === "prompt") { if (data[i].type === "prompt") {
chatData.value.push(data[i]); chatData.value.push(data[i]);
continue; continue;
} else if (data[i].type === "mj") {
data[i].content = JSON.parse(data[i].content)
data[i].content.content = md.render(data[i].content?.content)
chatData.value.push(data[i]);
continue;
} }
data[i].orgContent = data[i].content; data[i].orgContent = data[i].content;

View File

@ -166,6 +166,10 @@
:created-at="dateFormat(item['created_at'])" :created-at="dateFormat(item['created_at'])"
:tokens="item['tokens']" :tokens="item['tokens']"
:content="item.content"/> :content="item.content"/>
<chat-mid-journey v-else-if="item.type==='mj'"
:content="item.content"
:icon="item.icon"
:created-at="dateFormat(item['created_at'])"/>
</div> </div>
</div><!-- end chat box --> </div><!-- end chat box -->
@ -277,6 +281,7 @@ import {checkSession} from "@/action/session";
import BindMobile from "@/components/BindMobile.vue"; import BindMobile from "@/components/BindMobile.vue";
import RewardVerify from "@/components/RewardVerify.vue"; import RewardVerify from "@/components/RewardVerify.vue";
import Welcome from "@/components/Welcome.vue"; import Welcome from "@/components/Welcome.vue";
import ChatMidJourney from "@/components/ChatMidJourney.vue";
const title = ref('ChatGPT-智能助手'); const title = ref('ChatGPT-智能助手');
const logo = 'images/logo.png'; const logo = 'images/logo.png';
@ -542,12 +547,39 @@ const connect = function (chat_id, role_id) {
icon: _role['icon'], icon: _role['icon'],
content: "" content: ""
}); });
} else if (data.type === 'end') { // } else if (data.type === "mj") {
canSend.value = true; canSend.value = false;
showReGenerate.value = true; showReGenerate.value = false;
showStopGenerate.value = false; showStopGenerate.value = true;
lineBuffer.value = ''; // const content = data.content;
const md = require('markdown-it')({breaks: true});
content.content = md.render(content.content)
// console.log(content)
// check if the message is in chatData
let flag = false
for (let i = 0; i < chatData.value.length; i++) {
if (chatData.value[i].id === content.key) {
console.log(chatData.value[i])
flag = true
chatData.value[i].content = content
break
}
}
if (flag === false) {
chatData.value.push({
type: "mj",
id: content.key,
icon: "/images/avatar/mid_journey.png",
content: content
});
}
if (content.status === "Finished") {
canSend.value = true;
showReGenerate.value = true;
showStopGenerate.value = false;
}
} else if (data.type === 'end') { //
// //
if (isNewChat && newChatItem.value !== null) { if (isNewChat && newChatItem.value !== null) {
newChatItem.value['title'] = previousText.value; newChatItem.value['title'] = previousText.value;
@ -556,9 +588,18 @@ const connect = function (chat_id, role_id) {
activeChat.value = newChatItem.value; activeChat.value = newChatItem.value;
newChatItem.value = null; // newChatItem.value = null; //
} }
const reply = chatData.value[chatData.value.length - 1]
if (reply.content.indexOf("绘画提示词:") === -1) {
return
}
canSend.value = true;
showReGenerate.value = true;
showStopGenerate.value = false;
lineBuffer.value = ''; //
// token // token
const reply = chatData.value[chatData.value.length - 1]
httpGet(`/api/chat/tokens?text=${reply.orgContent}&model=${model.value}`).then(res => { httpGet(`/api/chat/tokens?text=${reply.orgContent}&model=${model.value}`).then(res => {
reply['created_at'] = new Date().getTime(); reply['created_at'] = new Date().getTime();
reply['tokens'] = res.data; reply['tokens'] = res.data;
@ -723,6 +764,11 @@ const loadChatHistory = function (chatId) {
if (data[i].type === "prompt") { if (data[i].type === "prompt") {
chatData.value.push(data[i]); chatData.value.push(data[i]);
continue; continue;
} else if (data[i].type === "mj") {
data[i].content = JSON.parse(data[i].content)
data[i].content.content = md.render(data[i].content?.content)
chatData.value.push(data[i]);
continue;
} }
data[i].orgContent = data[i].content; data[i].orgContent = data[i].content;

View File

@ -65,7 +65,7 @@ const login = function () {
if (username.value === '') { if (username.value === '') {
return ElMessage.error('请输入用户名'); return ElMessage.error('请输入用户名');
} }
if (password.value.trim() === '') { if (password.value === '') {
return ElMessage.error('请输入密码'); return ElMessage.error('请输入密码');
} }