mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: chat with file function is ready
This commit is contained in:
		@@ -323,20 +323,46 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		reqMgs = append(reqMgs, m)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullPrompt := prompt
 | 
			
		||||
	text := prompt
 | 
			
		||||
	// extract files in prompt
 | 
			
		||||
	files := utils.ExtractFileURLs(prompt)
 | 
			
		||||
	logger.Debugf("detected FILES: %+v", files)
 | 
			
		||||
	if len(files) > 0 {
 | 
			
		||||
		contents := make([]string, 0)
 | 
			
		||||
		var file model.File
 | 
			
		||||
		for _, v := range files {
 | 
			
		||||
			h.DB.Where("url = ?", v).First(&file)
 | 
			
		||||
			content, err := utils.ReadFileContent(v)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
 | 
			
		||||
			}
 | 
			
		||||
			text = strings.Replace(text, v, "", 1)
 | 
			
		||||
		}
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
 | 
			
		||||
		if tokens > session.Model.MaxContext {
 | 
			
		||||
			return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debug("最终Prompt:", fullPrompt)
 | 
			
		||||
 | 
			
		||||
	if session.Model.Platform == types.QWen.Value {
 | 
			
		||||
		req.Input = make(map[string]interface{})
 | 
			
		||||
		reqMgs = append(reqMgs, types.Message{
 | 
			
		||||
			Role:    "user",
 | 
			
		||||
			Content: prompt,
 | 
			
		||||
			Content: fullPrompt,
 | 
			
		||||
		})
 | 
			
		||||
		req.Input["messages"] = reqMgs
 | 
			
		||||
	} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
 | 
			
		||||
		imgURLs := utils.ExtractImgURL(prompt)
 | 
			
		||||
		imgURLs := utils.ExtractImgURLs(prompt)
 | 
			
		||||
		logger.Debugf("detected IMG: %+v", imgURLs)
 | 
			
		||||
		var content interface{}
 | 
			
		||||
		if len(imgURLs) > 0 {
 | 
			
		||||
			data := make([]interface{}, 0)
 | 
			
		||||
			text := prompt
 | 
			
		||||
			for _, v := range imgURLs {
 | 
			
		||||
				text = strings.Replace(text, v, "", 1)
 | 
			
		||||
				data = append(data, gin.H{
 | 
			
		||||
@@ -352,7 +378,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
			})
 | 
			
		||||
			content = data
 | 
			
		||||
		} else {
 | 
			
		||||
			content = prompt
 | 
			
		||||
			content = fullPrompt
 | 
			
		||||
		}
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
@@ -361,7 +387,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
			"content": prompt,
 | 
			
		||||
			"content": fullPrompt,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -454,7 +480,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
		h.DB.Where("id", session.Model.KeyId).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -79,7 +79,7 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	// use the bind key
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
 | 
			
		||||
		res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -215,7 +215,7 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	if chatModel.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
		res = h.DB.Where("id", chatModel.KeyId).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	file := "http://nk.img.r9it.com/chatgpt-plus/1719389335351828.xlsx"
 | 
			
		||||
	content, err := utils.ReadPdf(file)
 | 
			
		||||
	content, err := utils.ReadFileContent(file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,7 +13,8 @@ import (
 | 
			
		||||
	"github.com/google/go-tika/tika"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ReadPdf(filePath string) (string, error) {
 | 
			
		||||
func ReadFileContent(filePath string) (string, error) {
 | 
			
		||||
	// for remote file, download it first
 | 
			
		||||
	if strings.HasPrefix(filePath, "http") {
 | 
			
		||||
		file, err := downloadFile(filePath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -31,22 +32,34 @@ func ReadPdf(filePath string) (string, error) {
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
 | 
			
		||||
	// 使用 Tika 提取 PDF 文件的文本内容
 | 
			
		||||
	html, err := client.Parse(context.TODO(), file)
 | 
			
		||||
	content, err := client.Parse(context.TODO(), file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse file: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Println(html)
 | 
			
		||||
 | 
			
		||||
	return cleanBlankLine(html), nil
 | 
			
		||||
	ext := filepath.Ext(filePath)
 | 
			
		||||
	switch ext {
 | 
			
		||||
	case ".doc", ".docx", ".pdf", ".pptx", "ppt":
 | 
			
		||||
		return cleanBlankLine(cleanHtml(content, false)), nil
 | 
			
		||||
	case ".xls", ".xlsx":
 | 
			
		||||
		return cleanBlankLine(cleanHtml(content, true)), nil
 | 
			
		||||
	default:
 | 
			
		||||
		return cleanBlankLine(content), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 清理文本内容
 | 
			
		||||
func cleanHtml(html string) string {
 | 
			
		||||
func cleanHtml(html string, keepTable bool) string {
 | 
			
		||||
	// 清理 HTML 标签
 | 
			
		||||
	p := bluemonday.StrictPolicy()
 | 
			
		||||
	return p.Sanitize(html)
 | 
			
		||||
	var policy *bluemonday.Policy
 | 
			
		||||
	if keepTable {
 | 
			
		||||
		policy = bluemonday.NewPolicy()
 | 
			
		||||
		policy.AllowElements("table", "thead", "tbody", "tfoot", "tr", "td", "th")
 | 
			
		||||
	} else {
 | 
			
		||||
		policy = bluemonday.StrictPolicy()
 | 
			
		||||
	}
 | 
			
		||||
	return policy.Sanitize(html)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func cleanBlankLine(content string) string {
 | 
			
		||||
@@ -57,6 +70,12 @@ func cleanBlankLine(content string) string {
 | 
			
		||||
		if len(line) < 2 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// discard image
 | 
			
		||||
		if strings.HasSuffix(line, ".png") ||
 | 
			
		||||
			strings.HasSuffix(line, ".jpg") ||
 | 
			
		||||
			strings.HasSuffix(line, ".jpeg") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		texts = append(texts, line)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -88,7 +88,7 @@ func GetImgExt(filename string) string {
 | 
			
		||||
	return ext
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExtractImgURL(text string) []string {
 | 
			
		||||
func ExtractImgURLs(text string) []string {
 | 
			
		||||
	re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`)
 | 
			
		||||
	matches := re.FindAllStringSubmatch(text, 10)
 | 
			
		||||
	urls := make([]string, 0)
 | 
			
		||||
@@ -99,3 +99,15 @@ func ExtractImgURL(text string) []string {
 | 
			
		||||
	}
 | 
			
		||||
	return urls
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fileURLPattern matches an http(s) URL that ends in a supported document
// extension (doc, docx, pdf, ppt, pptx, xls, xlsx, txt). The URL body is
// restricted to non-whitespace characters so that a match can never start at
// one URL and run across surrounding prose into a later one.
var fileURLPattern = regexp.MustCompile(`(http[s]?:\/\/\S*?\.(?:docx?|pdf|pptx?|xlsx?|txt))`)

// ExtractFileURLs returns the document URLs found in text, in order of
// appearance. At most 10 URLs are returned. The result is never nil; it is an
// empty slice when text contains no matching URL.
func ExtractFileURLs(text string) []string {
	matches := fileURLPattern.FindAllStringSubmatch(text, 10)
	urls := make([]string, 0, len(matches))
	for _, m := range matches {
		// m[1] is the single capture group, i.e. the whole URL.
		urls = append(urls, m[1])
	}
	return urls
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user