From 5d2a1d21d5a560fb243b29b1645904f957e576d6 Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 14 Aug 2023 17:59:21 +0800 Subject: [PATCH] feat: midjourney drawing image function is ready --- api/core/types/chat.go | 4 +- api/core/types/function.go | 6 +- api/core/types/web.go | 2 +- api/handler/chat_handler.go | 7 +- api/handler/mj_handler.go | 61 +++++-- api/service/function/mid_journey.go | 16 +- api/utils/common.go | 4 + web/public/images/avatar/mid_journey.png | Bin 0 -> 5713 bytes web/src/components/ChatMidJourney.vue | 215 +++++++++++++++++++++++ web/src/components/ChatPrompt.vue | 2 +- web/src/views/ChatExport.vue | 10 ++ web/src/views/ChatPlus.vue | 58 +++++- web/src/views/admin/Login.vue | 2 +- 13 files changed, 357 insertions(+), 30 deletions(-) create mode 100644 web/public/images/avatar/mid_journey.png create mode 100644 web/src/components/ChatMidJourney.vue diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 163a0aea..efcd344e 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -43,7 +43,6 @@ type ChatSession struct { } type MjTask struct { - Client Client ChatId string MessageId string MessageHash string @@ -63,6 +62,7 @@ type ApiError struct { const PromptMsg = "prompt" // prompt message const ReplyMsg = "reply" // reply message +const MjMsg = "mj" var ModelToTokens = map[string]int{ "gpt-3.5-turbo": 4096, @@ -70,3 +70,5 @@ var ModelToTokens = map[string]int{ "gpt-4": 8192, "gpt-4-32k": 32768, } + +const TaskStorePrefix = "/tasks/" diff --git a/api/core/types/function.go b/api/core/types/function.go index a4361263..cc92dcd9 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -87,7 +87,11 @@ var InnerFunctions = []Function{ }, "ar": { Type: "string", - Description: "图片长宽比,如 16:9", + Description: "图片长宽比,如 16:9, --ar 3:2", + }, + "niji": { + Type: "string", + Description: "动漫模型版本,如 --niji 5", }, }, Required: []string{}, diff --git a/api/core/types/web.go b/api/core/types/web.go index d100e592..43ed032c 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -21,7 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") - WsImg = WsMsgType("img") + WsMjImg = WsMsgType("mj") ) type BizCode int diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index dc61173b..2ad6427c 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -27,7 +27,6 @@ import ( ) const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" -const TaskStorePrefix = "/tasks/" type ChatHandler struct { BaseHandler @@ -342,16 +341,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession content := data if functionName == types.FuncMidJourney { key := utils.Sha256(data) + //logger.Info(data, ",", key) // add task for MidJourney h.App.MjTaskClients.Put(key, ws) task := types.MjTask{ UserId: userVo.Id, RoleId: role.Id, - Icon: role.Icon, - Client: ws, + Icon: "/images/avatar/mid_journey.png", ChatId: session.ChatId, } - err := h.leveldb.Put(TaskStorePrefix+key, task) + err := h.leveldb.Put(types.TaskStorePrefix+key, task) if err != nil { logger.Error("error with store MidJourney task: ", err) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 365e5191..493b6a0c 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,9 +3,12 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/store" + "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type TaskStatus string @@ -29,10 +32,12 @@ type Image struct { type MidJourneyHandler struct { BaseHandler + leveldb *store.LevelDB + db *gorm.DB } -func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler { - h := MidJourneyHandler{} +func NewMidJourneyHandler(app *core.AppServer, leveldb *store.LevelDB, db *gorm.DB) *MidJourneyHandler { + h := MidJourneyHandler{leveldb: leveldb, db: db} h.App = app return &h } @@ -57,18 +62,54 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + key := utils.Sha256(data.Prompt) data.Key = key - // TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录 + //logger.Info(data.Prompt, ",", key) + if data.Status == Finished { + var task types.MjTask + err := h.leveldb.Get(types.TaskStorePrefix+key, &task) + if err != nil { + logger.Error("error with get MidJourney task: ", err) + resp.ERROR(c) + return + } - wsClient := h.App.MjTaskClients.Get(key) - if wsClient == nil { // 客户端断线,则丢弃 - resp.SUCCESS(c) - return + // TODO: 是否需要把图片下载到本地服务器? + + historyUserMsg := model.HistoryMessage{ + UserId: task.UserId, + ChatId: task.ChatId, + RoleId: task.RoleId, + Type: types.MjMsg, + Icon: task.Icon, + Content: utils.JsonEncode(data), + Tokens: 0, + UseContext: false, + } + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("error with save MidJourney message: ", res.Error) + } + + // delete task from leveldb + _ = h.leveldb.Delete(types.TaskStorePrefix + key) } // 推送消息到客户端 - // TODO: 增加绘画消息类型 - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data}) - resp.ERROR(c, "Error with CallBack") + wsClient := h.App.MjTaskClients.Get(key) + if wsClient == nil { // 客户端断线,则丢弃 + resp.SUCCESS(c, "Client is offline") + return + } + + if data.Status == Finished { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) + // delete client + h.App.MjTaskClients.Delete(key) + } else { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + } + resp.SUCCESS(c, "SUCCESS") } diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index 2c510e3f..52b83053 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -21,7 +21,7 @@ func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney { return FuncMidJourney{ name: "MidJourney AI 绘画", config: config, - client: req.C().SetTimeout(10 * time.Second)} + client: req.C().SetTimeout(30 * time.Second)} } func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { @@ -29,13 +29,19 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { return "", errors.New("无效的 API Token") } - logger.Infof("MJ 绘画参数:%+v", params) + //logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) if !utils.IsEmptyValue(params["ar"]) { - prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"]) - delete(params, "ar") + prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) + delete(params, "--ar") } - prompt = prompt + " --niji 5" + if !utils.IsEmptyValue(params["niji"]) { + prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"]) + delete(params, "niji") + } else { + prompt = prompt + " --v 5.2" + } + params["prompt"] = prompt url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL) var res types.BizVo r, err := f.client.R(). diff --git a/api/utils/common.go b/api/utils/common.go index 4924c587..a88c286d 100644 --- a/api/utils/common.go +++ b/api/utils/common.go @@ -89,6 +89,10 @@ func Ip2Region(searcher *xdb.Searcher, ip string) string { } func IsEmptyValue(obj interface{}) bool { + if obj == nil { + return true + } + v := reflect.ValueOf(obj) switch v.Kind() { case reflect.Ptr, reflect.Interface: diff --git a/web/public/images/avatar/mid_journey.png b/web/public/images/avatar/mid_journey.png new file mode 100644 index 0000000000000000000000000000000000000000..e239bb094a4d429430825639c5257279cb5c62e8 GIT binary patch literal 5713 zcmbVQ2{=@3`yWCGNtCTjgDlOMG1eLTPO{5Vq?s`o%wlHjYr>RWb}HG+Ua}XJ5F=|L zQBgu7MJP-E(OX^L_rKn6xxRB<=bY!9=Xd|^^|@!BbJ^0|h?`TG6952kn;7d`Gw!Cl zCp$Ca+5ld>&$w}rj2$Te0MFyylPUcKj|c$3+JdunpgNeHLSl$`Ig~3AjgEeA>wz;4B4cfW%xy72UKoTcNJ|r_ z9)M&J;IUK`FaYmOpdbS@KtFkrjPdSkCFg>5I+$4k^@wCFP(=<7!N6cJ zpbA1x9<8i|QbB-$a2Om8g(*Yj6(I5|NJV9&ydv<|0b-~lySgE*^$mY1V?1eqJg8I> z5(@SA_m}fmkRy`aq4Ed>0t$me;cy580igsEsHgx4fg=8!K_5%OkZ~j`jz|FRGNRB# zU#bR(Vd`H#;7NaI36x)PVnhrYfFeQV+o z{2wgIgGeP(Jc$2-`p@(~DPY9b%BBHlCV%~EQRPx z#$XNn7-ov^dP73$k+CQ$k!(vOdjHOpZXAHl8|q#Kcp zM=>16;Zg2bD2d<>0{#gjQjh3OBr_B=Ybv!TgrDAmbQI3+4Sk^Xx9T-I72W<0y<|1^!wQ zHdvouBX1n=XB8k(nBD5u0AY61iggA3dX4*=2>f>*{6*{Ufn|{Xn{@vLqY&Mw{wOk5 z$DI+a|5#4Ye`P!c<@fJR|L>CbuTB4_^7~&*{}on@2a4d1W%NoYXxAC&Zd3UyDbW9) zub*@O(13o!8C`96_@|#T9{zM>EP+8yX7t-F&A4^|fLGl_U&l7!>S9h{hz+*6E01DM zHeV|?$tf{gzgC~dd@{d+RY)KOS6PV7jh$n4XEPFdSl$kbb<&A@r243aU$msh8f%lm zSq!;zsM&H{SEYGqNO{wbT)%0wMALzSBz)eJwYIlYEHr1;KdpQVeLed$S4s*BukK>~ zQ60V0!o4*uED^GyWW2HLmApJP3Y==M)_K%&tGYSl##F)U$}QdLK>1EvpW6!`g-(Zr zI4n6eX}pE_OC#F`+|kh^oeR=f0Rw{-(i9Zc2mRuf9-G#@aaB~b ztJKiY>CEkW)SZe~`#Qq>-$MhNhFQ$7AIg%qtF+703P1OWYFq8%a{s;Akn$TBN?1x} zly0=LrdrF58``DA>FMc^{?MJyspt99Ayhl`U|a5exOR|or`aORtjRm<&NkqS8}vRz7XoqIe5la>UYa>AB%X#En;R?? zXTP_6@dv!2;pE#wmRRk0+Cuo(#=@rjuXBTm;(J-hx_n1zz#q-#?NYj}pz}AHb>C6l z;*aj(ibJVu=H(?PCturM2+gY?Y6RP5@XByo;}skDAhW_*293#^Mf#KZ=RofiE7cz$)_b=o2Cat(j0q64NNK7e+%`W2}D zRl~P+lP#4`&2(9hB?wPzCFd#6?>h#7Y1{4041a_69;6pP&|{L0M@67~-aL)9C>vY$ z!Hk-*h!6$&8Wung zUR?eHbNnHCOn)Ev?6p|VC*wVCwA4e#hQ{~-$%foaEhz;zZJDe>H%R@(D<4)5FlUcj zsQRZ=M>QI|H91MvjU5dRl2yfL#a$S0RtOJS9DfD6-z{O}qz?qu+;#C*K{4A2XL15! zi>3ten_(|Ud>2NXRmnw*TL?$*ELG_V3w^DO>Zl0dTF$<;Sh@mbHsWH`wt@ShHd#;a z1A^S_mPJvb=>nP|WNLW(d~CV0x`ei<@IJtu4Oe%f`{%wRQs!4xLzf>sA1;*)s?!&H zD^?=bv%jStj)y`yh1X%5i^feAVG9{&1*Ae$R@!Fv>RvyiC%+~h82G-RfQ3t7 zIf6SX3m19Az5Uqoamek@0(Q5{uy0GG_?l9E#oCAqUV8xfy1DhPxab#maKk0lZ9B-$ z#j5Xfg)Y1=yR_2+CF4kftKS*fW1_cGWubv-ck3HY6lvR96yVGb8LiSS zsH^)vnVv@f&{7^gzM{veY0hjZD?2}bEcQ0iFkqj4gEtF@j_CCQfF$RIlX3c^#+Sdu zwaQjigO2z$Yjz64nOYQTwKx!?hg$==5ToR^q~EAajB=4!_-es4c4wZ<`I`l zb)Dtcm?|%G>pYC#+zAQ9Ln@1N#w3JW8X7C%pB8STtPAb1W(o*vLEi@%#uU<)*7VrMDmT|8JW z>@eAhfzUOzUVmg?s_ilrNOF8p^d?{7OsrQk&mt>;rQ;DfPo`R0jkLiFaK&OZRMaa8 zT|9a=#bWoEfUQ6DJiV)A`;4`guS;8f8hA_(@Z{&e%h4ES@@^gfac!G>G$AUxv?v=k zf#mW%ey7UgLaMKmMir;fld;B=6+APl^&gXJFU2g_mwPfjn5UK$6}_1+G^%_VRq#sv zM2k(>&0~u-2hals-M->`0v`^rmU!!&k zRvX52vJ`#I@beYh5C-tF9Dx&2&R?Hv1h3VAO?e+zIk!}WT8-I|2mFJ>fJKC_=lL#~ zBrzZ9i-pltZw?96-@GCF!`OZAoLNO5Zod)UA_B)YFWhcv1-HIA=O2}O+el){`SrE> zb$>tvQc^tVyN#OPM{QmfgTd9+Rn~ono-J-oP}X0c^0-GmevAXGx-e~u3~1J?Ta(PY zR{i1hi!9rt@SNIGpAyt1nS`27I^Hwk$Q4!n@#n`IUmTXSePLR**@j>Tu*XXF=7z7% z4SxR~WHJ7{-2Nrh*?Y37!zfsDy%>=2QcFl+;S@c-iD~P3u&|$t@_OXZ^fNwJKF^an zo5S)eu-;p#<5p%a7RDCkA39GLPsNB*m|tK1lAT+>lq@uC6;CUF7L!B2Ihd85ZDN(? z06xw>?O&U>XQFA!VwRg;>GO^9w+ZE&oLJS6rw}VYSjn1cj4oIX*T-kA6*ljHqnqCG%sAA60NPo-QO!>zJEn)k4Rf<=6Xj z&iM(?pS;GG?sE5}md?d8CeqU7imu>=DrCG0llFn@65KqR0n5O={)RJ^faKIv@i~uA zT0RqAG|dC!gvmKC4#KX zy!);qy?=|aVwwy9?$rhW0_CI{YFCRYi~CY-v~bh;dNe1wOb4)=M-1@Ny?gi8zL2X` zPUo?2nu{)~`&$d;%&|exqsHE6iqBS}*;&!n`HL`;&RoTTq32vSm@c&d)W@?TnepK$ z&=bK!m5kXU*AbCZE*U zGK;G}ztO?3T1H5$1CmDRQefwpde3o!yQuSd1?J_A2YlzJA55_&DyueV?p)@b?JAOd zyx;7q`q$f5N$LewUIysT8+T1qsucD203tweM=cr3kDcwH`{!iShZHM^=kL_kS_tsZ zmAD2DP6t?wg$7j5SQ|VV<2YPRrEHs1%FU2S$Af0H#FH74J2}<1UtBp)N$c~b70YCh zDf|W^+!gI3E{?um&TQXqBg})lk7&a2r-EWcGtZ2-$Ydr!lW(Z2Q+u=e?b_@uEw!&m z06gl>SrQwAAcWB=-VD zOO9?PEOb)b*E2I-fn{bse>RaO%3q}Q5>`5fJc%(R&~s0h2_kJ(#Lo(4Pj`&v@pauM(6TWe zor?wDZX;%JIVq!N!Ip&j7ZZEfG76jw3+FqzGfHGi7BFC&I)oN9u9%egg zEo7o5z>97Zayir*!WAEc!xdLi<&pWKBRrj+$2RvGmv=lH>9kU$4{#ayhuQbu0eAT{ z46{e;UeJEv%e{b^DL!TsE}%wv-#`6uPo?z`!Bs}Z3@JF&-F&qo z*VWsubLzaP-^dI(x3}oU#95oSK4ae|eG=a;-Hj`d&g(a_vX#p;mEcP-6wjn92`4y! zGw3s*6ln*x#zc#R_j}jVW(u%nE~`W(y?JHPgS68TX&XIjA&qonV&8g!DIip3pZiC~1?2*79n8>a@fEbh|OGiM~EJ#m5aDSn6&qW&W$#h5D z%E}6cLLRca2OW%>EM3`@sd z+MX9Qv`RCUXbxTQGgCikaYaIh_qu)JTV58QmPysbJrT1Bvjwz_{@FbdPu)_DC9Y2a z9t^}wQB1w)J?SUb7hg&pic?6smPV6&rIl;6A0}iHcdop;T1+1Oc;oY174I<^jFed` zdHu4K3ArOy;GmRwZ*Q+to#$x_izHgl#jb&PcR9d=Dbl@_KDGJz`QcIvmuBj%haSt` z%7Xh91i6X#x9vPVJyEEscipLUI^A)6((dkk{HKwf=9P8gW|-F#%c{<{8Gq+SUUjmB z;Tx5c=G{TTocs3K*w{?blq#xj^>ugWgFtn9G(NcF=KQd1`-73nvy52{`k2jG7M-%P zvM*o0cppe%l9y+X^|QA%CioFw)YW z7S>3eAKIA6HmQcDL_cPV`07R(eyxe$E+9S4QEF;xA~#K>Qr>nZ5L>^w>D&2Dm&a*% z^$+=Od+wYW^8NrlQTuYy$Hyn`5F|&*Yh-D%#U*IIc+8z_ozYs%F2W3W-{+;|^=f7x z%WWvsvLk8rax(cvkSWh-=uTkoYUv~Lmw=5Uc^c!B-;>Oq8YLY+KQ~_I#dN;=isUuf zI=_wOn%Vb($ljMJiw0`Qo}L2l5iC6;U;6u0^L}X7sT7+M{XR#MyE(w)#CRC~wPwM6 zJUs)#8I68%2m74I{au<2s=rP9QhGzXY6k8{*^T*usYZ7Svy@rO%E%txS70@&?}D|Acu&@Ljj|n57{b#y%W{{@_)O6AS%Vl@gj>?3=6Ur9tIwuUVEVmE7uz zd@yzJlW<`9w7zdxSXcK!0Pp50^N+m~H{x;r{f^G(evDd}lq>nHfwxl?#Ilm_iX6GW zc&COX-e}8~3ot+^E&eJCcmhgY3N=nCo_t@A)RxDQ_>&q(nOI}Uj!2OUnBrsd>xo5^ z=4bsC@&~qTX&-5&-+y!|1#o`OEQ#?7a>wYzEm`g-oLZKv+&rseKO OUlRj!{VLt_k^cvE9PeWQ literal 0 HcmV?d00001 diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue new file mode 100644 index 00000000..2a963c51 --- /dev/null +++ b/web/src/components/ChatMidJourney.vue @@ -0,0 +1,215 @@ + + + + + \ No newline at end of file diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index cfc848bb..5f56e675 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -64,7 +64,7 @@ export default defineComponent({ }) -