From 6998dd7af409a76868c417b7782c0e193c60b9e9 Mon Sep 17 00:00:00 2001 From: RockYang Date: Thu, 27 Jun 2024 18:01:49 +0800 Subject: [PATCH] feat: chat with file function is ready --- api/handler/chatimpl/chat_handler.go | 38 +++- api/handler/chatimpl/xunfei_handler.go | 2 +- api/handler/markmap_handler.go | 2 +- api/test/test.go | 2 +- api/utils/file.go | 35 +++- api/utils/upload.go | 14 +- web/babel.config.js | 2 +- web/src/assets/css/markdown/vue.css | 237 +++++++++++++++++++++++++ web/src/components/ChatPrompt.vue | 10 +- web/src/components/ChatReply.vue | 14 +- web/src/components/FileSelect.vue | 1 + web/src/utils/libs.js | 33 +--- web/src/views/ChatPlus.vue | 10 +- web/src/views/admin/ChatModel.vue | 4 +- 14 files changed, 329 insertions(+), 75 deletions(-) create mode 100644 web/src/assets/css/markdown/vue.css diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index a99c2897..98c2eefe 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -323,20 +323,46 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio reqMgs = append(reqMgs, m) } + fullPrompt := prompt + text := prompt + // extract files in prompt + files := utils.ExtractFileURLs(prompt) + logger.Debugf("detected FILES: %+v", files) + if len(files) > 0 { + contents := make([]string, 0) + var file model.File + for _, v := range files { + h.DB.Where("url = ?", v).First(&file) + content, err := utils.ReadFileContent(v) + if err == nil { + contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content)) + } + text = strings.Replace(text, v, "", 1) + } + if len(contents) > 0 { + fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text) + } + + tokens, _ := utils.CalcTokens(fullPrompt, req.Model) + if tokens > session.Model.MaxContext { + return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。") + } + } + logger.Debug("最终Prompt:", fullPrompt) + if session.Model.Platform == types.QWen.Value { req.Input = make(map[string]interface{}) reqMgs = append(reqMgs, types.Message{ Role: "user", - Content: prompt, + Content: fullPrompt, }) req.Input["messages"] = reqMgs } else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model - imgURLs := utils.ExtractImgURL(prompt) + imgURLs := utils.ExtractImgURLs(prompt) logger.Debugf("detected IMG: %+v", imgURLs) var content interface{} if len(imgURLs) > 0 { data := make([]interface{}, 0) - text := prompt for _, v := range imgURLs { text = strings.Replace(text, v, "", 1) data = append(data, gin.H{ @@ -352,7 +378,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio }) content = data } else { - content = prompt + content = fullPrompt } req.Messages = append(reqMgs, map[string]interface{}{ "role": "user", @@ -361,7 +387,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } else { req.Messages = append(reqMgs, map[string]interface{}{ "role": "user", - "content": prompt, + "content": fullPrompt, }) } @@ -454,7 +480,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) { func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) { // if the chat model bind a KEY, use it directly if session.Model.KeyId > 0 { - h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey) + h.DB.Where("id", session.Model.KeyId).Find(apiKey) } // use the last unused key if apiKey.Id == 0 { diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index e4a081fc..7ebea8d5 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -79,7 +79,7 @@ func (h *ChatHandler) sendXunFeiMessage( var res *gorm.DB // use the bind key if session.Model.KeyId > 0 { - res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey) + res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey) } // use the last unused key if apiKey.Id == 0 { diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index d6565444..368d12f7 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -215,7 +215,7 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod // if the chat model bind a KEY, use it directly var res *gorm.DB if chatModel.KeyId > 0 { - res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey) + res = h.DB.Where("id", chatModel.KeyId).Find(apiKey) } // use the last unused key if apiKey.Id == 0 { diff --git a/api/test/test.go b/api/test/test.go index b50418a0..3fae23e1 100644 --- a/api/test/test.go +++ b/api/test/test.go @@ -7,7 +7,7 @@ import ( func main() { file := "http://nk.img.r9it.com/chatgpt-plus/1719389335351828.xlsx" - content, err := utils.ReadPdf(file) + content, err := utils.ReadFileContent(file) if err != nil { panic(err) } diff --git a/api/utils/file.go b/api/utils/file.go index 220d61a4..69534462 100644 --- a/api/utils/file.go +++ b/api/utils/file.go @@ -13,7 +13,8 @@ import ( "github.com/google/go-tika/tika" ) -func ReadPdf(filePath string) (string, error) { +func ReadFileContent(filePath string) (string, error) { + // for remote file, download it first if strings.HasPrefix(filePath, "http") { file, err := downloadFile(filePath) if err != nil { @@ -31,22 +32,34 @@ func ReadPdf(filePath string) (string, error) { defer file.Close() // 使用 Tika 提取 PDF 文件的文本内容 - html, err := client.Parse(context.TODO(), file) + content, err := client.Parse(context.TODO(), file) if err != nil { return "", fmt.Errorf("error with parse file: %v", err) } - fmt.Println(html) - - return cleanBlankLine(html), nil + ext := filepath.Ext(filePath) + switch ext { + case ".doc", ".docx", ".pdf", ".pptx", "ppt": + return cleanBlankLine(cleanHtml(content, false)), nil + case ".xls", ".xlsx": + return cleanBlankLine(cleanHtml(content, true)), nil + default: + return cleanBlankLine(content), nil + } } // 清理文本内容 -func cleanHtml(html string) string { +func cleanHtml(html string, keepTable bool) string { // 清理 HTML 标签 - p := bluemonday.StrictPolicy() - return p.Sanitize(html) + var policy *bluemonday.Policy + if keepTable { + policy = bluemonday.NewPolicy() + policy.AllowElements("table", "thead", "tbody", "tfoot", "tr", "td", "th") + } else { + policy = bluemonday.StrictPolicy() + } + return policy.Sanitize(html) } func cleanBlankLine(content string) string { @@ -57,6 +70,12 @@ func cleanBlankLine(content string) string { if len(line) < 2 { continue } + // discard image + if strings.HasSuffix(line, ".png") || + strings.HasSuffix(line, ".jpg") || + strings.HasSuffix(line, ".jpeg") { + continue + } texts = append(texts, line) } diff --git a/api/utils/upload.go b/api/utils/upload.go index 5c764d8e..5227b361 100644 --- a/api/utils/upload.go +++ b/api/utils/upload.go @@ -88,7 +88,7 @@ func GetImgExt(filename string) string { return ext } -func ExtractImgURL(text string) []string { +func ExtractImgURLs(text string) []string { re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`) matches := re.FindAllStringSubmatch(text, 10) urls := make([]string, 0) @@ -99,3 +99,15 @@ func ExtractImgURL(text string) []string { } return urls } + +func ExtractFileURLs(text string) []string { + re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:docx?|pdf|pptx?|xlsx?|txt))`) + matches := re.FindAllStringSubmatch(text, 10) + urls := make([]string, 0) + if len(matches) > 0 { + for _, m := range matches { + urls = append(urls, m[1]) + } + } + return urls +} diff --git a/web/babel.config.js b/web/babel.config.js index e9558405..7f03dcb2 100644 --- a/web/babel.config.js +++ b/web/babel.config.js @@ -1,5 +1,5 @@ module.exports = { presets: [ - '@vue/cli-plugin-babel/preset' + '@vue.css/cli-plugin-babel/preset' ] } diff --git a/web/src/assets/css/markdown/vue.css b/web/src/assets/css/markdown/vue.css new file mode 100644 index 00000000..296b9a7c --- /dev/null +++ b/web/src/assets/css/markdown/vue.css @@ -0,0 +1,237 @@ +.chat-line { + ol, ul { + margin: 0.8em 0; + list-style: normal; + } + a { + color: #42b983; + font-weight: 600; + padding: 0 2px; + text-decoration: none; + } + + h1, + h2, + h3, + h4, + h5, + h6 { + position: relative; + margin-top: 1rem; + margin-bottom: 1rem; + font-weight: bold; + line-height: 1.4; + cursor: text; + } + + h1:hover a.anchor, + h2:hover a.anchor, + h3:hover a.anchor, + h4:hover a.anchor, + h5:hover a.anchor, + h6:hover a.anchor { + text-decoration: none; + } + + h1 tt, + h1 code { + font-size: inherit !important; + } + + h2 tt, + h2 code { + font-size: inherit !important; + } + + h3 tt, + h3 code { + font-size: inherit !important; + } + + h4 tt, + h4 code { + font-size: inherit !important; + } + + h5 tt, + h5 code { + font-size: inherit !important; + } + + h6 tt, + h6 code { + font-size: inherit !important; + } + + h2 a, + h3 a { + color: #34495e; + } + + h1 { + padding-bottom: .4rem; + font-size: 2.2rem; + line-height: 1.3; + } + + h2 { + font-size: 1.75rem; + line-height: 1.225; + margin: 35px 0 15px; + padding-bottom: 0.5em; + border-bottom: 1px solid #ddd; + } + + h3 { + font-size: 1.4rem; + line-height: 1.43; + margin: 20px 0 7px; + } + + h4 { + font-size: 1.2rem; + } + + h5 { + font-size: 1rem; + } + + h6 { + font-size: 1rem; + color: #777; + } + + p, + blockquote, + ul, + ol, + dl, + table { + margin: 0.8em 0; + } + + li > ol, + li > ul { + margin: 0 0; + } + + hr { + height: 2px; + padding: 0; + margin: 16px 0; + background-color: #e7e7e7; + border: 0 none; + overflow: hidden; + box-sizing: content-box; + } + + body > h2:first-child { + margin-top: 0; + padding-top: 0; + } + + body > h1:first-child { + margin-top: 0; + padding-top: 0; + } + + body > h1:first-child + h2 { + margin-top: 0; + padding-top: 0; + } + + body > h3:first-child, + body > h4:first-child, + body > h5:first-child, + body > h6:first-child { + margin-top: 0; + padding-top: 0; + } + + a:first-child h1, + a:first-child h2, + a:first-child h3, + a:first-child h4, + a:first-child h5, + a:first-child h6 { + margin-top: 0; + padding-top: 0; + } + + h1 p, + h2 p, + h3 p, + h4 p, + h5 p, + h6 p { + margin-top: 0; + } + + li p.first { + display: inline-block; + } + + ul, + ol { + padding-left: 30px; + } + + ul:first-child, + ol:first-child { + margin-top: 0; + } + + ul:last-child, + ol:last-child { + margin-bottom: 0; + } + + blockquote { + border-left: 4px solid #42b983; + padding: 10px 15px; + color: #777; + background-color: rgba(66, 185, 131, .1); + } + + table { + padding: 0; + word-break: initial; + } + + table tr { + border-top: 1px solid #dfe2e5; + margin: 0; + padding: 0; + } + + table tr:nth-child(2n), + thead { + background-color: #fafafa; + } + + table tr th { + font-weight: bold; + border: 1px solid #dfe2e5; + border-bottom: 0; + text-align: left; + margin: 0; + padding: 6px 13px; + } + + table tr td { + border: 1px solid #dfe2e5; + text-align: left; + margin: 0; + padding: 6px 13px; + } + + table tr th:first-child, + table tr td:first-child { + margin-top: 0; + } + + table tr th:last-child, + table tr td:last-child { + margin-bottom: 0; + } +} diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue index eda29117..e571496c 100644 --- a/web/src/components/ChatPrompt.vue +++ b/web/src/components/ChatPrompt.vue @@ -16,7 +16,9 @@
-
{{file.name}}
+
+ {{file.name}} +
{{GetFileType(file.ext)}} {{FormatFileSize(file.size)}} @@ -121,6 +123,7 @@ onMounted(() => {