mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 10:43:44 +08:00
feat: chat with file function is ready
This commit is contained in:
@@ -323,20 +323,46 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
reqMgs = append(reqMgs, m)
|
||||
}
|
||||
|
||||
fullPrompt := prompt
|
||||
text := prompt
|
||||
// extract files in prompt
|
||||
files := utils.ExtractFileURLs(prompt)
|
||||
logger.Debugf("detected FILES: %+v", files)
|
||||
if len(files) > 0 {
|
||||
contents := make([]string, 0)
|
||||
var file model.File
|
||||
for _, v := range files {
|
||||
h.DB.Where("url = ?", v).First(&file)
|
||||
content, err := utils.ReadFileContent(v)
|
||||
if err == nil {
|
||||
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
|
||||
}
|
||||
|
||||
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
|
||||
if tokens > session.Model.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
}
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
if session.Model.Platform == types.QWen.Value {
|
||||
req.Input = make(map[string]interface{})
|
||||
reqMgs = append(reqMgs, types.Message{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
Content: fullPrompt,
|
||||
})
|
||||
req.Input["messages"] = reqMgs
|
||||
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
|
||||
imgURLs := utils.ExtractImgURL(prompt)
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
text := prompt
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
@@ -352,7 +378,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
content = prompt
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
@@ -361,7 +387,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content": fullPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -454,7 +480,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if session.Model.KeyId > 0 {
|
||||
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
|
||||
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
|
||||
@@ -79,7 +79,7 @@ func (h *ChatHandler) sendXunFeiMessage(
|
||||
var res *gorm.DB
|
||||
// use the bind key
|
||||
if session.Model.KeyId > 0 {
|
||||
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
|
||||
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
|
||||
@@ -215,7 +215,7 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod
|
||||
// if the chat model bind a KEY, use it directly
|
||||
var res *gorm.DB
|
||||
if chatModel.KeyId > 0 {
|
||||
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
|
||||
res = h.DB.Where("id", chatModel.KeyId).Find(apiKey)
|
||||
}
|
||||
// use the last unused key
|
||||
if apiKey.Id == 0 {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
func main() {
|
||||
file := "http://nk.img.r9it.com/chatgpt-plus/1719389335351828.xlsx"
|
||||
content, err := utils.ReadPdf(file)
|
||||
content, err := utils.ReadFileContent(file)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
"github.com/google/go-tika/tika"
|
||||
)
|
||||
|
||||
func ReadPdf(filePath string) (string, error) {
|
||||
func ReadFileContent(filePath string) (string, error) {
|
||||
// for remote file, download it first
|
||||
if strings.HasPrefix(filePath, "http") {
|
||||
file, err := downloadFile(filePath)
|
||||
if err != nil {
|
||||
@@ -31,22 +32,34 @@ func ReadPdf(filePath string) (string, error) {
|
||||
defer file.Close()
|
||||
|
||||
// 使用 Tika 提取 PDF 文件的文本内容
|
||||
html, err := client.Parse(context.TODO(), file)
|
||||
content, err := client.Parse(context.TODO(), file)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse file: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println(html)
|
||||
|
||||
return cleanBlankLine(html), nil
|
||||
ext := filepath.Ext(filePath)
|
||||
switch ext {
|
||||
case ".doc", ".docx", ".pdf", ".pptx", "ppt":
|
||||
return cleanBlankLine(cleanHtml(content, false)), nil
|
||||
case ".xls", ".xlsx":
|
||||
return cleanBlankLine(cleanHtml(content, true)), nil
|
||||
default:
|
||||
return cleanBlankLine(content), nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 清理文本内容
|
||||
func cleanHtml(html string) string {
|
||||
func cleanHtml(html string, keepTable bool) string {
|
||||
// 清理 HTML 标签
|
||||
p := bluemonday.StrictPolicy()
|
||||
return p.Sanitize(html)
|
||||
var policy *bluemonday.Policy
|
||||
if keepTable {
|
||||
policy = bluemonday.NewPolicy()
|
||||
policy.AllowElements("table", "thead", "tbody", "tfoot", "tr", "td", "th")
|
||||
} else {
|
||||
policy = bluemonday.StrictPolicy()
|
||||
}
|
||||
return policy.Sanitize(html)
|
||||
}
|
||||
|
||||
func cleanBlankLine(content string) string {
|
||||
@@ -57,6 +70,12 @@ func cleanBlankLine(content string) string {
|
||||
if len(line) < 2 {
|
||||
continue
|
||||
}
|
||||
// discard image
|
||||
if strings.HasSuffix(line, ".png") ||
|
||||
strings.HasSuffix(line, ".jpg") ||
|
||||
strings.HasSuffix(line, ".jpeg") {
|
||||
continue
|
||||
}
|
||||
texts = append(texts, line)
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ func GetImgExt(filename string) string {
|
||||
return ext
|
||||
}
|
||||
|
||||
func ExtractImgURL(text string) []string {
|
||||
func ExtractImgURLs(text string) []string {
|
||||
re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`)
|
||||
matches := re.FindAllStringSubmatch(text, 10)
|
||||
urls := make([]string, 0)
|
||||
@@ -99,3 +99,15 @@ func ExtractImgURL(text string) []string {
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func ExtractFileURLs(text string) []string {
|
||||
re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:docx?|pdf|pptx?|xlsx?|txt))`)
|
||||
matches := re.FindAllStringSubmatch(text, 10)
|
||||
urls := make([]string, 0)
|
||||
if len(matches) > 0 {
|
||||
for _, m := range matches {
|
||||
urls = append(urls, m[1])
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user