diff --git a/api/core/app_server.go b/api/core/app_server.go index abb0f31d..96747bc3 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -147,6 +147,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/chat/history" || c.Request.URL.Path == "/api/chat/detail" || + c.Request.URL.Path == "/api/mj/proxy" || strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || strings.HasPrefix(c.Request.URL.Path, "/static/") || diff --git a/api/core/types/locked_map.go b/api/core/types/locked_map.go index c69098ff..13915c43 100644 --- a/api/core/types/locked_map.go +++ b/api/core/types/locked_map.go @@ -9,7 +9,7 @@ type MKey interface { string | int } type MValue interface { - *WsClient | *ChatSession | context.CancelFunc | []interface{} | MjTask + *WsClient | *ChatSession | context.CancelFunc | []interface{} } type LMap[K MKey, T MValue] struct { lock sync.RWMutex diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 414157c9..d90f5e4f 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -4,7 +4,6 @@ import ( "chatplus/core" "chatplus/core/types" "chatplus/service" - "chatplus/service/function" "chatplus/service/oss" "chatplus/store/model" "chatplus/utils" @@ -22,8 +21,6 @@ import ( type TaskStatus string const ( - Start = TaskStatus("Started") - Running = TaskStatus("Running") Stopped = TaskStatus("Stopped") Finished = TaskStatus("Finished") ) @@ -64,42 +61,58 @@ func NewMidJourneyHandler( return &h } +type notifyData struct { + MessageId string `json:"message_id"` + ReferenceId string `json:"reference_id"` + Image Image `json:"image"` + Content string `json:"content"` + Prompt string `json:"prompt"` + Status TaskStatus `json:"status"` + Progress int `json:"progress"` +} + func (h *MidJourneyHandler) Notify(c *gin.Context) { token := c.GetHeader("Authorization") if token != h.App.Config.ExtConfig.Token { resp.NotAuth(c) return } - - var data struct { - MessageId string `json:"message_id"` - ReferenceId string `json:"reference_id"` - Image Image `json:"image"` - Content string `json:"content"` - Prompt string `json:"prompt"` - Status TaskStatus `json:"status"` - Progress int `json:"progress"` - } + var data notifyData if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { resp.ERROR(c, types.InvalidArgs) return } - logger.Debugf("收到 MidJourney 回调请求:%+v", data) + h.lock.Lock() defer h.lock.Unlock() - taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() + err := h.notifyHandler(c, data) if err != nil { - resp.SUCCESS(c) // 过期任务,丢弃 + resp.ERROR(c, err.Error()) return } + // 解除任务锁定 + if data.Status == Finished || data.Status == Stopped { + h.redis.Del(c, service.MjRunningJobKey) + } + resp.SUCCESS(c) + +} + +func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) error { + taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() + if err != nil { // 过期任务,丢弃 + logger.Warn("任务已过期:", err) + return nil + } + var task service.MjTask err = utils.JsonDecode(taskString, &task) - if err != nil { - resp.SUCCESS(c) // 非标准任务,丢弃 - return + if err != nil { // 非标准任务,丢弃 + logger.Warn("任务解析失败:", err) + return nil } if task.Src == service.TaskSrcImg { // 绘画任务 @@ -107,19 +120,20 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { var job model.MidJourneyJob res := h.db.First(&job, task.Id) if res.Error != nil { - resp.SUCCESS(c) // 非法任务,丢弃 - return + logger.Warn("非法任务:", err) + return nil } job.MessageId = data.MessageId job.ReferenceId = data.ReferenceId job.Progress = data.Progress + job.Prompt = data.Prompt // download image if data.Progress == 100 { imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) if err != nil { - resp.ERROR(c, "error with download img: "+err.Error()) - return + logger.Error("error with download img: ", err.Error()) + return err } job.ImgURL = imgURL } else { @@ -128,18 +142,16 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { } res = h.db.Updates(&job) if res.Error != nil { - resp.ERROR(c, "error with update job: "+err.Error()) - return + logger.Error("error with update job: ", err.Error()) + return res.Error } - resp.SUCCESS(c) - } else if task.Src == service.TaskSrcChat { // 聊天任务 var job model.MidJourneyJob res := h.db.Where("message_id = ?", data.MessageId).First(&job) if res.Error == nil { - resp.SUCCESS(c) - return + logger.Warn("重复消息:", data.MessageId) + return nil } wsClient := h.App.MjTaskClients.Get(task.Id) @@ -156,10 +168,10 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) utils.ReplyMessage(wsClient, content) } - resp.ERROR(c, err.Error()) - return + return err } + tx := h.db.Begin() data.Image.URL = imgURL message := model.HistoryMessage{ UserId: uint(task.UserId), @@ -171,9 +183,9 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { Tokens: 0, UseContext: false, } - res := h.db.Create(&message) + res = tx.Create(&message) if res.Error != nil { - logger.Error("error with save chat history message: ", res.Error) + return res.Error } // save the job @@ -184,16 +196,17 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { job.ImgURL = imgURL job.Progress = data.Progress job.CreatedAt = time.Now() - res = h.db.Create(&job) + res = tx.Create(&job) if res.Error != nil { - logger.Error("error with save MidJourney Job: ", res.Error) + tx.Rollback() + return res.Error } + tx.Commit() } if wsClient == nil { // 客户端断线,则丢弃 logger.Errorf("Client is offline: %+v", data) - resp.SUCCESS(c, "Client is offline") - return + return nil } if data.Status == Finished { @@ -202,22 +215,17 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) { // delete client h.App.MjTaskClients.Delete(task.Id) } else { - //// 使用代理临时转发图片 - //if data.Image.URL != "" { - // image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL) - // if err == nil { - // data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - // } - //} data.Image.URL = fmt.Sprintf("/api/mj/proxy?url=%s", data.Image.URL) utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) } - resp.SUCCESS(c, "SUCCESS") } + return nil } func (h *MidJourneyHandler) Proxy(c *gin.Context) { + logger.Info(c.Request.Host, c.Request.Proto) + return url := c.Query("url") image, err := utils.DownloadImage(url, h.App.Config.ProxyURL) if err != nil { @@ -232,16 +240,16 @@ type reqVo struct { MessageId string `json:"message_id"` MessageHash string `json:"message_hash"` SessionId string `json:"session_id"` - Key string `json:"key"` Prompt string `json:"prompt"` + ChatId string `json:"chat_id"` + RoleId int `json:"role_id"` + Icon string `json:"icon"` } // Upscale send upscale command to MidJourney Bot func (h *MidJourneyHandler) Upscale(c *gin.Context) { var data reqVo - if err := c.ShouldBindJSON(&data); err != nil || - data.SessionId == "" || - data.Key == "" { + if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { resp.ERROR(c, types.InvalidArgs) return } @@ -250,35 +258,32 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { resp.ERROR(c, "No Websocket client online") return } - + userId, _ := c.Get(types.LoginUserID) h.mjService.PushTask(service.MjTask{ + Id: data.SessionId, + Src: service.TaskSrcChat, + Type: service.Upscale, + Prompt: data.Prompt, + UserId: utils.IntValue(utils.InterfaceToString(userId), 0), + RoleId: data.RoleId, + Icon: data.Icon, + ChatId: data.ChatId, Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, }) - err := n.Upscale(function.MjUpscaleReq{ - Index: data.Index, - MessageId: data.MessageId, - MessageHash: data.MessageHash, - }) - if err != nil { - resp.ERROR(c, err.Error()) - return - } content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) utils.ReplyMessage(wsClient, content) - if h.App.MjTaskClients.Get(data.Key) == nil { - h.App.MjTaskClients.Put(data.Key, wsClient) + if h.App.MjTaskClients.Get(data.SessionId) == nil { + h.App.MjTaskClients.Put(data.SessionId, wsClient) } resp.SUCCESS(c) } func (h *MidJourneyHandler) Variation(c *gin.Context) { var data reqVo - if err := c.ShouldBindJSON(&data); err != nil || - data.SessionId == "" || - data.Key == "" { + if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { resp.ERROR(c, types.InvalidArgs) return } @@ -288,19 +293,24 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { return } - err := h.mjFunc.Variation(function.MjVariationReq{ + userId, _ := c.Get(types.LoginUserID) + h.mjService.PushTask(service.MjTask{ + Id: data.SessionId, + Src: service.TaskSrcChat, + Type: service.Variation, + Prompt: data.Prompt, + UserId: utils.IntValue(utils.InterfaceToString(userId), 0), + RoleId: data.RoleId, + Icon: data.Icon, + ChatId: data.ChatId, Index: data.Index, MessageId: data.MessageId, MessageHash: data.MessageHash, }) - if err != nil { - resp.ERROR(c, err.Error()) - return - } content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) utils.ReplyMessage(wsClient, content) - if h.App.MjTaskClients.Get(data.Key) == nil { - h.App.MjTaskClients.Put(data.Key, wsClient) + if h.App.MjTaskClients.Get(data.SessionId) == nil { + h.App.MjTaskClients.Put(data.SessionId, wsClient) } resp.SUCCESS(c) } diff --git a/api/main.go b/api/main.go index f04cccd5..4d6738ec 100644 --- a/api/main.go +++ b/api/main.go @@ -136,7 +136,7 @@ func main() { }), fx.Provide(oss.NewUploaderManager), fx.Provide(service.NewMjService), - fx.Provide(func(mjService *service.MjService) { + fx.Invoke(func(mjService *service.MjService) { go func() { mjService.Run() }() diff --git a/api/service/function/func_mj.go b/api/service/function/func_mj.go index 79cc7313..71e49b55 100644 --- a/api/service/function/func_mj.go +++ b/api/service/function/func_mj.go @@ -48,6 +48,7 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { f.service.PushTask(service.MjTask{ Id: utils.InterfaceToString(params["session_id"]), Src: service.TaskSrcChat, + Type: service.Image, Prompt: prompt, UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), diff --git a/api/service/function/function.go b/api/service/function/function.go index d490ac1c..5b558feb 100644 --- a/api/service/function/function.go +++ b/api/service/function/function.go @@ -3,6 +3,7 @@ package function import ( "chatplus/core/types" logger2 "chatplus/logger" + "chatplus/service" ) type Function interface { @@ -28,11 +29,11 @@ type dataItem struct { Remark string `json:"remark"` } -func NewFunctions(config *types.AppConfig) map[string]Function { +func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function { return map[string]Function{ types.FuncZaoBao: NewZaoBao(config.ApiConfig), types.FuncWeibo: NewWeiboHot(config.ApiConfig), types.FuncHeadLine: NewHeadLines(config.ApiConfig), - types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig), + types.FuncMidJourney: NewMidJourneyFunc(mjService), } } diff --git a/api/service/mj_service.go b/api/service/mj_service.go index ccd70ed9..b0295168 100644 --- a/api/service/mj_service.go +++ b/api/service/mj_service.go @@ -56,19 +56,21 @@ type MjService struct { redis *redis.Client } -func NewMjService(config types.ChatPlusExtConfig, client *redis.Client) *MjService { +func NewMjService(appConfig *types.AppConfig, client *redis.Client) *MjService { return &MjService{ - config: config, + config: appConfig.ExtConfig, redis: client, taskQueue: store.NewRedisQueue("midjourney_task_queue", client), client: req.C().SetTimeout(30 * time.Second)} } func (s *MjService) Run() { + logger.Info("Starting MidJourney job consumer.") ctx := context.Background() for { - _, err := s.redis.Get(ctx, MjRunningJobKey).Result() - if err == nil { // a task is running, waiting for finish + t, err := s.redis.Get(ctx, MjRunningJobKey).Result() + if err == nil { + logger.Infof("An task is not finished: %s", t) time.Sleep(time.Second * 3) continue } @@ -78,7 +80,7 @@ func (s *MjService) Run() { logger.Errorf("taking task with error: %v", err) continue } - + logger.Infof("Consuming Task: %+v", task) switch task.Type { case Image: err = s.image(task.Prompt) @@ -98,11 +100,11 @@ func (s *MjService) Run() { }) } if err != nil { + logger.Error("绘画任务执行失败:", err) if task.RetryCount > 5 { continue } task.RetryCount += 1 - time.Sleep(time.Second) s.taskQueue.RPush(task) // TODO: 执行失败通知聊天客户端 continue @@ -114,6 +116,7 @@ func (s *MjService) Run() { } func (s *MjService) PushTask(task MjTask) { + logger.Infof("add a new MidJourney Task: %+v", task) s.taskQueue.RPush(task) } diff --git a/api/test/test.go b/api/test/test.go index 026a54da..5d5bdf6f 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -1,199 +1,12 @@ package main import ( - "bufio" - "chatplus/core" - "chatplus/core/types" - "chatplus/service/oss" - "chatplus/utils" - "context" - "encoding/json" "fmt" - "github.com/lionsoul2014/ip2region/binding/golang/xdb" - "github.com/pkoukk/tiktoken-go" - "io" - "log" - "net/http" - "os" "path/filepath" - "strings" - "time" ) func main() { - imageURL := "https://cdn.discordapp.com/attachments/1139552247693443184/1141619433752768572/lisamiller4099_A_beautiful_fairy_sister_from_Chinese_mythology__3162726e-5ee4-4f60-932b-6b78b375eaef.png" + imageURL := "https://cdn.discordapp.com/attachments/1151037077308325901/1151286701717733416/jiangjin_a_chrysanthemum_in_the_style_of_Van_Gogh_49b64011-6581-469d-9888-c285ab964e08.png" fmt.Println(filepath.Ext(filepath.Base(imageURL))) } - -// Http client 取消操作 -func testHttpClient(ctx context.Context) { - - req, err := http.NewRequest("GET", "http://localhost:2345", nil) - if err != nil { - fmt.Println(err) - return - } - - req = req.WithContext(ctx) - - client := &http.Client{} - - resp, err := client.Do(req) - if err != nil { - fmt.Println(err) - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - - } - }(resp.Body) - _, err = io.ReadAll(resp.Body) - for { - time.Sleep(time.Second) - fmt.Println(time.Now()) - select { - case <-ctx.Done(): - fmt.Println("取消退出") - return - default: - continue - } - } - -} - -func testDate() { - fmt.Println(time.Unix(1683336167, 0).Format("2006-01-02 15:04:05")) -} - -func testIp2Region() { - dbPath := "res/ip2region.xdb" - // 1、从 dbPath 加载整个 xdb 到内存 - cBuff, err := xdb.LoadContentFromFile(dbPath) - if err != nil { - fmt.Printf("failed to load content from `%s`: %s\n", dbPath, err) - return - } - - // 2、用全局的 cBuff 创建完全基于内存的查询对象。 - searcher, err := xdb.NewWithBuffer(cBuff) - if err != nil { - fmt.Printf("failed to create searcher with content: %s\n", err) - return - } - - str, err := searcher.SearchByStr("103.88.46.85") - fmt.Println(str) - if err != nil { - log.Fatal(err) - } - arr := strings.Split(str, "|") - fmt.Println(arr[2], arr[3], arr[4]) - -} - -func calTokens() { - text := "须知少年凌云志,曾许人间第一流" - encoding := "cl100k_base" - - tke, err := tiktoken.GetEncoding(encoding) - if err != nil { - err = fmt.Errorf("getEncoding: %v", err) - return - } - - // encode - token := tke.Encode(text, nil, nil) - - //tokens - fmt.Println(token) - // num_tokens - fmt.Println(len(token)) - -} - -func testAesEncrypt() { - // 加密 - text := []byte("this is a secret text") - key := utils.RandString(24) - encrypt, err := utils.AesEncrypt(key, text) - if err != nil { - panic(err) - } - fmt.Println("加密密文:", encrypt) - // 解密 - decrypt, err := utils.AesDecrypt(key, encrypt) - if err != nil { - panic(err) - } - fmt.Println("解密明文:", string(decrypt)) -} - -func extractFunction() error { - open, err := os.Open("res/data.txt") - if err != nil { - return err - } - reader := bufio.NewReader(open) - var contents = make([]string, 0) - var functionCall = false - var functionName string - for { - line, err := reader.ReadString('\n') - if err != nil { - break - } - if !strings.Contains(line, "data:") { - continue - } - - var responseBody = types.ApiResponse{} - err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错 - break - } - - function := responseBody.Choices[0].Delta.FunctionCall - if functionCall && function.Name == "" { - contents = append(contents, function.Arguments) - continue - } - - if !utils.IsEmptyValue(function) { - functionCall = true - functionName = function.Name - continue - } - } - - fmt.Println("函数名称: ", functionName) - fmt.Println(strings.Join(contents, "")) - return err -} - -func minio() { - config := core.NewDefaultConfig() - config.ProxyURL = "http://localhost:7777" - config.OSS.Minio = types.MinioConfig{ - Endpoint: "localhost:9010", - AccessKey: "ObWIEyXaQUHOYU26L0oI", - AccessSecret: "AJW3HHhlGrprfPcmiC7jSOSzVCyrlhX4AnOAUzqI", - Bucket: "chatgpt-plus", - UseSSL: false, - Domain: "http://localhost:9010", - } - minioService, err := oss.NewMinioService(config) - if err != nil { - panic(err) - } - - url, err := minioService.PutImg("https://cdn.discordapp.com/attachments/1139552247693443184/1141619433752768572/lisamiller4099_A_beautiful_fairy_sister_from_Chinese_mythology__3162726e-5ee4-4f60-932b-6b78b375eaef.png") - if err != nil { - panic(err) - } - - fmt.Println(url) -} diff --git a/database/update-3.1.3.sql b/database/update-v3.1.3.sql similarity index 72% rename from database/update-3.1.3.sql rename to database/update-v3.1.3.sql index 9871fcfa..984020cc 100644 --- a/database/update-3.1.3.sql +++ b/database/update-v3.1.3.sql @@ -1,4 +1,7 @@ ALTER TABLE `chatgpt_mj_jobs` DROP `image`; +ALTER TABLE `chatgpt_mj_jobs` DROP `hash`; +ALTER TABLE `chatgpt_mj_jobs` DROP `content`; +ALTER TABLE `chatgpt_mj_jobs` DROP `chat_id`; ALTER TABLE `chatgpt_mj_jobs` ADD `progress` SMALLINT(5) NULL DEFAULT '0' COMMENT '任务进度' AFTER `prompt`; ALTER TABLE `chatgpt_mj_jobs` ADD `hash` VARCHAR(100) NULL DEFAULT NULL COMMENT 'message hash' AFTER `prompt`; ALTER TABLE `chatgpt_mj_jobs` ADD `img_url` VARCHAR(255) NULL DEFAULT NULL COMMENT '图片URL' AFTER `prompt`; \ No newline at end of file diff --git a/web/src/components/ChatMidJourney.vue b/web/src/components/ChatMidJourney.vue index a62e03ca..75928df8 100644 --- a/web/src/components/ChatMidJourney.vue +++ b/web/src/components/ChatMidJourney.vue @@ -70,6 +70,8 @@ import {getSessionId} from "@/store/session"; const props = defineProps({ content: Object, icon: String, + chatId: String, + roleId: Number, createdAt: String }); @@ -110,6 +112,9 @@ const send = (url, index) => { message_hash: data.value?.["image"]?.hash, session_id: getSessionId(), prompt: data.value?.["prompt"], + chat_id: props.chatId, + role_id: props.roleId, + icon: props.icon, }).then(() => { ElMessage.success("任务推送成功,请耐心等待任务执行...") loading.value = false diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index 4bc07991..0bcc2eee 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -44,7 +44,7 @@ export default defineComponent({ default: 0, }, model: { - type: Number, + type: String, default: '', }, }, diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index e8dc7d88..cc8e2bae 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -157,7 +157,7 @@ :icon="item.icon" :created-at="dateFormat(item['created_at'])" :tokens="item['tokens']" - :model="getModelValue(modelID.value)" + :model="getModelValue(modelID)" :content="item.content"/> { httpGet(`/api/role/list?user_id=${user.id}`).then((res) => { roles.value = res.data; roleId.value = roles.value[0]['id']; - const chatId = router.currentRoute.value.params['id'] + const chatId = localStorage.getItem("chat_id") const chat = getChatById(chatId) if (chat === null) { // 创建新的对话 newChat(); } else { // 加载对话 - loadChat(chat) + changeChat(chat) } }).catch((e) => { ElMessage.error('获取聊天角色失败: ' + e.messages) @@ -353,16 +355,16 @@ onMounted(() => { // TODO: 增加重试按钮 ElMessage.error("加载会话列表失败!") }) + + httpGet("/api/admin/config/get?key=system").then(res => { + title.value = res.data.title + }).catch(e => { + ElMessage.error("获取系统配置失败:" + e.message) + }) }).catch(() => { router.push('/login') }); - httpGet("/api/admin/config/get?key=system").then(res => { - title.value = res.data.title - }).catch(e => { - ElMessage.error("获取系统配置失败:" + e.message) - }) - const clipboard = new Clipboard('.copy-reply'); clipboard.on('success', () => { ElMessage.success('复制成功!'); @@ -422,6 +424,7 @@ const newChat = function () { // 切换会话 const changeChat = (chat) => { router.push("/chat/" + chat.chat_id) + localStorage.setItem("chat_id", chat.chat_id) loadChat(chat) } @@ -533,6 +536,8 @@ const connect = function (chat_id, role_id) { if (isNewChat) { // 加载打招呼信息 loading.value = false; chatData.value.push({ + chat_id: chat_id, + role_id: role_id, type: "reply", id: randString(32), icon: _role['icon'],