diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 163a0aea..efcd344e 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -43,7 +43,6 @@ type ChatSession struct { } type MjTask struct { - Client Client ChatId string MessageId string MessageHash string @@ -63,6 +62,7 @@ type ApiError struct { const PromptMsg = "prompt" // prompt message const ReplyMsg = "reply" // reply message +const MjMsg = "mj" var ModelToTokens = map[string]int{ "gpt-3.5-turbo": 4096, @@ -70,3 +70,5 @@ var ModelToTokens = map[string]int{ "gpt-4": 8192, "gpt-4-32k": 32768, } + +const TaskStorePrefix = "/tasks/" diff --git a/api/core/types/function.go b/api/core/types/function.go index a4361263..cc92dcd9 100644 --- a/api/core/types/function.go +++ b/api/core/types/function.go @@ -87,7 +87,11 @@ var InnerFunctions = []Function{ }, "ar": { Type: "string", - Description: "图片长宽比,如 16:9", + Description: "图片长宽比,如 16:9, --ar 3:2", + }, + "niji": { + Type: "string", + Description: "动漫模型版本,如 --niji 5", }, }, Required: []string{}, diff --git a/api/core/types/web.go b/api/core/types/web.go index d100e592..43ed032c 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -21,7 +21,7 @@ const ( WsStart = WsMsgType("start") WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end") - WsImg = WsMsgType("img") + WsMjImg = WsMsgType("mj") ) type BizCode int diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index dc61173b..2ad6427c 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -27,7 +27,6 @@ import ( ) const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" -const TaskStorePrefix = "/tasks/" type ChatHandler struct { BaseHandler @@ -342,16 +341,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession content := data if functionName == types.FuncMidJourney { key := utils.Sha256(data) + //logger.Info(data, ",", key) // add task for MidJourney h.App.MjTaskClients.Put(key, ws) task := types.MjTask{ UserId: userVo.Id, RoleId: role.Id, - Icon: role.Icon, - Client: ws, + Icon: "/images/avatar/mid_journey.png", ChatId: session.ChatId, } - err := h.leveldb.Put(TaskStorePrefix+key, task) + err := h.leveldb.Put(types.TaskStorePrefix+key, task) if err != nil { logger.Error("error with store MidJourney task: ", err) } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 365e5191..493b6a0c 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,9 +3,12 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/store" + "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type TaskStatus string @@ -29,10 +32,12 @@ type Image struct { type MidJourneyHandler struct { BaseHandler + leveldb *store.LevelDB + db *gorm.DB } -func NewMidJourneyHandler(app *core.AppServer) *MidJourneyHandler { - h := MidJourneyHandler{} +func NewMidJourneyHandler(app *core.AppServer, leveldb *store.LevelDB, db *gorm.DB) *MidJourneyHandler { + h := MidJourneyHandler{leveldb: leveldb, db: db} h.App = app return &h } @@ -57,18 +62,54 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + key := utils.Sha256(data.Prompt) data.Key = key - // TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录 + //logger.Info(data.Prompt, ",", key) + if data.Status == Finished { + var task types.MjTask + err := h.leveldb.Get(types.TaskStorePrefix+key, &task) + if err != nil { + logger.Error("error with get MidJourney task: ", err) + resp.ERROR(c) + return + } - wsClient := h.App.MjTaskClients.Get(key) - if wsClient == nil { // 客户端断线,则丢弃 - resp.SUCCESS(c) - return + // TODO: 是否需要把图片下载到本地服务器? + + historyUserMsg := model.HistoryMessage{ + UserId: task.UserId, + ChatId: task.ChatId, + RoleId: task.RoleId, + Type: types.MjMsg, + Icon: task.Icon, + Content: utils.JsonEncode(data), + Tokens: 0, + UseContext: false, + } + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("error with save MidJourney message: ", res.Error) + } + + // delete task from leveldb + _ = h.leveldb.Delete(types.TaskStorePrefix + key) } // 推送消息到客户端 - // TODO: 增加绘画消息类型 - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data}) - resp.ERROR(c, "Error with CallBack") + wsClient := h.App.MjTaskClients.Get(key) + if wsClient == nil { // 客户端断线,则丢弃 + resp.SUCCESS(c, "Client is offline") + return + } + + if data.Status == Finished { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) + // delete client + h.App.MjTaskClients.Delete(key) + } else { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + } + resp.SUCCESS(c, "SUCCESS") } diff --git a/api/service/function/mid_journey.go b/api/service/function/mid_journey.go index 2c510e3f..52b83053 100644 --- a/api/service/function/mid_journey.go +++ b/api/service/function/mid_journey.go @@ -21,7 +21,7 @@ func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney { return FuncMidJourney{ name: "MidJourney AI 绘画", config: config, - client: req.C().SetTimeout(10 * time.Second)} + client: req.C().SetTimeout(30 * time.Second)} } func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { @@ -29,13 +29,19 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { return "", errors.New("无效的 API Token") } - logger.Infof("MJ 绘画参数:%+v", params) + //logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) if !utils.IsEmptyValue(params["ar"]) { - prompt = prompt + fmt.Sprintf(" --ar %v", params["ar"]) - delete(params, "ar") + prompt = fmt.Sprintf("%s --ar %s", prompt, params["ar"]) + delete(params, "--ar") } - prompt = prompt + " --niji 5" + if !utils.IsEmptyValue(params["niji"]) { + prompt = fmt.Sprintf("%s --niji %s", prompt, params["niji"]) + delete(params, "niji") + } else { + prompt = prompt + " --v 5.2" + } + params["prompt"] = prompt url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL) var res types.BizVo r, err := f.client.R(). diff --git a/api/utils/common.go b/api/utils/common.go index 4924c587..a88c286d 100644 --- a/api/utils/common.go +++ b/api/utils/common.go @@ -89,6 +89,10 @@ func Ip2Region(searcher *xdb.Searcher, ip string) string { } func IsEmptyValue(obj interface{}) bool { + if obj == nil { + return true + } + v := reflect.ValueOf(obj) switch v.Kind() { case reflect.Ptr, reflect.Interface: diff --git a/web/public/images/avatar/mid_journey.png b/web/public/images/avatar/mid_journey.png new file mode 100644 index 00000000..e239bb09 Binary files /dev/null and b/web/public/images/avatar/mid_journey.png differ diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue new file mode 100644 index 00000000..2a963c51 --- /dev/null +++ b/web/src/components/ChatMidJourney.vue @@ -0,0 +1,215 @@ + + + + + + + + + + + + + + + 正在加载图片... + + + + + + + + + + + + + + + + + U1 + U2 + U3 + U4 + + + + + + V1 + V2 + V3 + V4 + + + + + + {{ createdAt }} + tokens: {{ tokens }} + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index cfc848bb..5f56e675 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -64,7 +64,7 @@ export default defineComponent({ }) -