feat: chat with file function is ready

This commit is contained in:
RockYang
2024-06-27 18:01:49 +08:00
parent 3fdcc895ed
commit a27ce36a32
14 changed files with 329 additions and 75 deletions

View File

@@ -323,20 +323,46 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, m)
}
fullPrompt := prompt
text := prompt
// extract files in prompt
files := utils.ExtractFileURLs(prompt)
logger.Debugf("detected FILES: %+v", files)
if len(files) > 0 {
contents := make([]string, 0)
var file model.File
for _, v := range files {
h.DB.Where("url = ?", v).First(&file)
content, err := utils.ReadFileContent(v)
if err == nil {
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
text = strings.Replace(text, v, "", 1)
}
if len(contents) > 0 {
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML)\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
}
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
if tokens > session.Model.MaxContext {
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
}
}
logger.Debug("最终Prompt", fullPrompt)
if session.Model.Platform == types.QWen.Value {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: prompt,
Content: fullPrompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURL(prompt)
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
text := prompt
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
@@ -352,7 +378,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
})
content = data
} else {
content = prompt
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
@@ -361,7 +387,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": prompt,
"content": fullPrompt,
})
}
@@ -454,7 +480,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {

View File

@@ -79,7 +79,7 @@ func (h *ChatHandler) sendXunFeiMessage(
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
}
// use the last unused key
if apiKey.Id == 0 {

View File

@@ -215,7 +215,7 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
res = h.DB.Where("id", chatModel.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {