mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	extract code for saving chat history
This commit is contained in:
		@@ -61,15 +61,15 @@ type ChatSession struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatModel struct {
 | 
			
		||||
	Id          uint     `json:"id"`
 | 
			
		||||
	Platform    Platform `json:"platform"`
 | 
			
		||||
	Name        string   `json:"name"`
 | 
			
		||||
	Value       string   `json:"value"`
 | 
			
		||||
	Power       int      `json:"power"`
 | 
			
		||||
	MaxTokens   int      `json:"max_tokens"`  // 最大响应长度
 | 
			
		||||
	MaxContext  int      `json:"max_context"` // 最大上下文长度
 | 
			
		||||
	Temperature float32  `json:"temperature"` // 模型温度
 | 
			
		||||
	KeyId       int      `json:"key_id"`      // 绑定 API KEY
 | 
			
		||||
	Id          uint    `json:"id"`
 | 
			
		||||
	Platform    string  `json:"platform"`
 | 
			
		||||
	Name        string  `json:"name"`
 | 
			
		||||
	Value       string  `json:"value"`
 | 
			
		||||
	Power       int     `json:"power"`
 | 
			
		||||
	MaxTokens   int     `json:"max_tokens"`  // 最大响应长度
 | 
			
		||||
	MaxContext  int     `json:"max_context"` // 最大上下文长度
 | 
			
		||||
	Temperature float32 `json:"temperature"` // 模型温度
 | 
			
		||||
	KeyId       int     `json:"key_id"`      // 绑定 API KEY
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ApiError struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -137,14 +137,44 @@ func (c RedisConfig) Url() string {
 | 
			
		||||
	return fmt.Sprintf("%s:%d", c.Host, c.Port)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Platform string
 | 
			
		||||
type Platform struct {
 | 
			
		||||
	Name    string `json:"name"`
 | 
			
		||||
	Value   string `json:"value"`
 | 
			
		||||
	ChatURL string `json:"chat_url"`
 | 
			
		||||
	ImgURL  string `json:"img_url"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const OpenAI = Platform("OpenAI")
 | 
			
		||||
const Azure = Platform("Azure")
 | 
			
		||||
const ChatGLM = Platform("ChatGLM")
 | 
			
		||||
const Baidu = Platform("Baidu")
 | 
			
		||||
const XunFei = Platform("XunFei")
 | 
			
		||||
const QWen = Platform("QWen")
 | 
			
		||||
var OpenAI = Platform{
 | 
			
		||||
	Name:    "OpenAI - GPT",
 | 
			
		||||
	Value:   "OpenAI",
 | 
			
		||||
	ChatURL: "https://api.chat-plus.net/v1/chat/completions",
 | 
			
		||||
	ImgURL:  "https://api.chat-plus.net/v1/images/generations",
 | 
			
		||||
}
 | 
			
		||||
var Azure = Platform{
 | 
			
		||||
	Name:    "微软 - Azure",
 | 
			
		||||
	Value:   "Azure",
 | 
			
		||||
	ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15",
 | 
			
		||||
}
 | 
			
		||||
var ChatGLM = Platform{
 | 
			
		||||
	Name:    "智谱 - ChatGLM",
 | 
			
		||||
	Value:   "ChatGLM",
 | 
			
		||||
	ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke",
 | 
			
		||||
}
 | 
			
		||||
var Baidu = Platform{
 | 
			
		||||
	Name:    "百度 - 文心大模型",
 | 
			
		||||
	Value:   "Baidu",
 | 
			
		||||
	ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
 | 
			
		||||
}
 | 
			
		||||
var XunFei = Platform{
 | 
			
		||||
	Name:    "讯飞 - 星火大模型",
 | 
			
		||||
	Value:   "XunFei",
 | 
			
		||||
	ChatURL: "wss://spark-api.xf-yun.com/{version}/chat",
 | 
			
		||||
}
 | 
			
		||||
var QWen = Platform{
 | 
			
		||||
	Name:    "阿里 - 通义千问",
 | 
			
		||||
	Value:   "QWen",
 | 
			
		||||
	ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SystemConfig struct {
 | 
			
		||||
	Title         string `json:"title,omitempty"`
 | 
			
		||||
 
 | 
			
		||||
@@ -28,8 +28,8 @@ type ConfigHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	levelDB        *store.LevelDB
 | 
			
		||||
	licenseService *service.LicenseService
 | 
			
		||||
	mjServicePool *mj.ServicePool
 | 
			
		||||
	sdServicePool *sd.ServicePool
 | 
			
		||||
	mjServicePool  *mj.ServicePool
 | 
			
		||||
	sdServicePool  *sd.ServicePool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
 | 
			
		||||
@@ -140,12 +140,13 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, license)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDrawingConfig 获取AI绘画配置
 | 
			
		||||
func (h *ConfigHandler) GetDrawingConfig(c *gin.Context) {
 | 
			
		||||
// GetAppConfig 获取内置配置
 | 
			
		||||
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, gin.H{
 | 
			
		||||
		"mj_plus":  h.App.Config.MjPlusConfigs,
 | 
			
		||||
		"mj_proxy": h.App.Config.MjProxyConfigs,
 | 
			
		||||
		"sd":       h.App.Config.SdConfigs,
 | 
			
		||||
		"mj_plus":   h.App.Config.MjPlusConfigs,
 | 
			
		||||
		"mj_proxy":  h.App.Config.MjProxyConfigs,
 | 
			
		||||
		"sd":        h.App.Config.SdConfigs,
 | 
			
		||||
		"platforms": Platforms,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										12
									
								
								api/handler/admin/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								api/handler/admin/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
 | 
			
		||||
var Platforms = []types.Platform{
 | 
			
		||||
	types.OpenAI,
 | 
			
		||||
	types.QWen,
 | 
			
		||||
	types.XunFei,
 | 
			
		||||
	types.ChatGLM,
 | 
			
		||||
	types.Baidu,
 | 
			
		||||
	types.Azure,
 | 
			
		||||
}
 | 
			
		||||
@@ -17,11 +17,9 @@ import (
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微软 Azure 模型消息发送实现
 | 
			
		||||
@@ -101,104 +99,12 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			replyTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     replyTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		var res types.ApiError
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(res.Error.Message, "maximum context length") {
 | 
			
		||||
			logger.Error(res.Error.Message)
 | 
			
		||||
			h.App.ChatContexts.Delete(session.ChatId)
 | 
			
		||||
			return h.sendMessage(ctx, session, role, prompt, ws)
 | 
			
		||||
		} else {
 | 
			
		||||
			return fmt.Errorf("请求 Azure API 失败:%v", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -17,12 +17,10 @@ import (
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type baiduResp struct {
 | 
			
		||||
@@ -130,99 +128,11 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code int    `json:"error_code"`
 | 
			
		||||
			Msg  string `json:"error_msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -23,11 +23,13 @@ import (
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
@@ -122,7 +124,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
		MaxContext:  chatModel.MaxContext,
 | 
			
		||||
		Temperature: chatModel.Temperature,
 | 
			
		||||
		KeyId:       chatModel.KeyId,
 | 
			
		||||
		Platform:    types.Platform(chatModel.Platform)}
 | 
			
		||||
		Platform:    chatModel.Platform}
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
 | 
			
		||||
 | 
			
		||||
	// 保存会话连接
 | 
			
		||||
@@ -218,11 +220,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		Stream: true,
 | 
			
		||||
	}
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
 | 
			
		||||
	case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		// OpenAI 支持函数功能
 | 
			
		||||
@@ -261,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
			req.Tools = tools
 | 
			
		||||
			req.ToolChoice = "auto"
 | 
			
		||||
		}
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		req.Parameters = map[string]interface{}{
 | 
			
		||||
			"max_tokens":  session.Model.MaxTokens,
 | 
			
		||||
			"temperature": session.Model.Temperature,
 | 
			
		||||
@@ -325,14 +327,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		reqMgs = append(reqMgs, m)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if session.Model.Platform == types.QWen {
 | 
			
		||||
	if session.Model.Platform == types.QWen.Value {
 | 
			
		||||
		req.Input = make(map[string]interface{})
 | 
			
		||||
		reqMgs = append(reqMgs, types.Message{
 | 
			
		||||
			Role:    "user",
 | 
			
		||||
			Content: prompt,
 | 
			
		||||
		})
 | 
			
		||||
		req.Input["messages"] = reqMgs
 | 
			
		||||
	} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
 | 
			
		||||
	} else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model
 | 
			
		||||
		imgURLs := utils.ExtractImgURL(prompt)
 | 
			
		||||
		logger.Debugf("detected IMG: %+v", imgURLs)
 | 
			
		||||
		var content interface{}
 | 
			
		||||
@@ -370,17 +372,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
	logger.Debugf("%+v", req.Messages)
 | 
			
		||||
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
	case types.XunFei.Value:
 | 
			
		||||
		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -467,7 +469,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ONLY allow apiURL in blank list
 | 
			
		||||
	if session.Model.Platform == types.OpenAI {
 | 
			
		||||
	if session.Model.Platform == types.OpenAI.Value {
 | 
			
		||||
		err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -476,19 +478,19 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 | 
			
		||||
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		md := strings.Replace(req.Model, ".", "", 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
@@ -498,7 +500,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// 百度文心,需要串接 access_token
 | 
			
		||||
	if session.Model.Platform == types.Baidu {
 | 
			
		||||
	if session.Model.Platform == types.Baidu.Value {
 | 
			
		||||
		token, err := h.getBaiduToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -534,22 +536,22 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
	case types.Azure.Value:
 | 
			
		||||
		request.Header.Set("api-key", apiKey.Value)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
	case types.ChatGLM.Value:
 | 
			
		||||
		token, err := h.getChatGLMToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
	case types.Baidu.Value:
 | 
			
		||||
		request.RequestURI = ""
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
	case types.OpenAI.Value:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
	case types.QWen.Value:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		request.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
		break
 | 
			
		||||
@@ -583,6 +585,97 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) saveChatHistory(
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	contents []string,
 | 
			
		||||
	message types.Message,
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	promptCreatedAt time.Time,
 | 
			
		||||
	replyCreatedAt time.Time) {
 | 
			
		||||
	if message.Role == "" {
 | 
			
		||||
		message.Role = "assistant"
 | 
			
		||||
	}
 | 
			
		||||
	message.Content = strings.Join(contents, "")
 | 
			
		||||
	useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
	// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
	if h.App.SysConfig.EnableContext {
 | 
			
		||||
		chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
		chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
		h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 追加聊天记录
 | 
			
		||||
	// for prompt
 | 
			
		||||
	promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
	historyUserMsg := model.ChatMessage{
 | 
			
		||||
		UserId:     userVo.Id,
 | 
			
		||||
		ChatId:     session.ChatId,
 | 
			
		||||
		RoleId:     role.Id,
 | 
			
		||||
		Type:       types.PromptMsg,
 | 
			
		||||
		Icon:       userVo.Avatar,
 | 
			
		||||
		Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
		Tokens:     promptToken,
 | 
			
		||||
		UseContext: true,
 | 
			
		||||
		Model:      req.Model,
 | 
			
		||||
	}
 | 
			
		||||
	historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
	historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
	res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// for reply
 | 
			
		||||
	// 计算本次对话消耗的总 token 数量
 | 
			
		||||
	replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
	totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
	historyReplyMsg := model.ChatMessage{
 | 
			
		||||
		UserId:     userVo.Id,
 | 
			
		||||
		ChatId:     session.ChatId,
 | 
			
		||||
		RoleId:     role.Id,
 | 
			
		||||
		Type:       types.ReplyMsg,
 | 
			
		||||
		Icon:       role.Icon,
 | 
			
		||||
		Content:    message.Content,
 | 
			
		||||
		Tokens:     totalTokens,
 | 
			
		||||
		UseContext: true,
 | 
			
		||||
		Model:      req.Model,
 | 
			
		||||
	}
 | 
			
		||||
	historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
	historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
	res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户算力
 | 
			
		||||
	h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
	// 保存当前会话
 | 
			
		||||
	var chatItem model.ChatItem
 | 
			
		||||
	res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		chatItem.ChatId = session.ChatId
 | 
			
		||||
		chatItem.UserId = session.UserId
 | 
			
		||||
		chatItem.RoleId = role.Id
 | 
			
		||||
		chatItem.ModelId = session.Model.Id
 | 
			
		||||
		if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
			chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
		} else {
 | 
			
		||||
			chatItem.Title = prompt
 | 
			
		||||
		}
 | 
			
		||||
		chatItem.Model = req.Model
 | 
			
		||||
		h.DB.Create(&chatItem)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 将AI回复消息中生成的图片链接下载到本地
 | 
			
		||||
func (h *ChatHandler) extractImgUrl(text string) string {
 | 
			
		||||
	pattern := `!\[([^\]]*)]\(([^)]+)\)`
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,6 @@ package chatimpl
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
@@ -18,11 +17,9 @@ import (
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 清华大学 ChatGML 消息发送实现
 | 
			
		||||
@@ -108,103 +105,11 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code    int    `json:"code"`
 | 
			
		||||
			Success bool   `json:"success"`
 | 
			
		||||
			Msg     string `json:"msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		if !res.Success {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg)
 | 
			
		||||
		}
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -17,13 +17,10 @@ import (
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	req2 "github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
 | 
			
		||||
	req2 "github.com/imroc/req/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// OPenAI 消息发送实现
 | 
			
		||||
@@ -178,126 +175,11 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext && toolCall == false {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			useContext := true
 | 
			
		||||
			if toolCall {
 | 
			
		||||
				useContext = false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: useContext,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			var replyTokens = 0
 | 
			
		||||
			if toolCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
				tokens, _ := utils.CalcTokens(function.Name, req.Model)
 | 
			
		||||
				replyTokens += tokens
 | 
			
		||||
				tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
				replyTokens += tokens
 | 
			
		||||
			} else {
 | 
			
		||||
				replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			}
 | 
			
		||||
			replyTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    h.extractImgUrl(message.Content),
 | 
			
		||||
				Tokens:     replyTokens,
 | 
			
		||||
				UseContext: useContext,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		var res types.ApiError
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// OpenAI API 调用异常处理
 | 
			
		||||
		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
 | 
			
		||||
			// 移除当前 API key
 | 
			
		||||
			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
 | 
			
		||||
			logger.Error(res.Error.Message)
 | 
			
		||||
			utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
 | 
			
		||||
			h.App.ChatContexts.Delete(session.ChatId)
 | 
			
		||||
			return h.sendMessage(ctx, session, role, prompt, ws)
 | 
			
		||||
		} else {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
 | 
			
		||||
		}
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求 OpenAI API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -10,18 +10,15 @@ package chatimpl
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/syndtr/goleveldb/leveldb/errors"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type qWenResp struct {
 | 
			
		||||
@@ -142,100 +139,11 @@ func (h *ChatHandler) sendQWenMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code int    `json:"error_code"`
 | 
			
		||||
			Msg  string `json:"error_msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
		return fmt.Errorf("请求大模型 API 失败:%s", body)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -21,13 +21,11 @@ import (
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type xunFeiResp struct {
 | 
			
		||||
@@ -181,89 +179,10 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 消息发送成功
 | 
			
		||||
	if len(contents) > 0 {
 | 
			
		||||
		if message.Role == "" {
 | 
			
		||||
			message.Role = "assistant"
 | 
			
		||||
		}
 | 
			
		||||
		message.Content = strings.Join(contents, "")
 | 
			
		||||
		useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
		// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
		if h.App.SysConfig.EnableContext {
 | 
			
		||||
			chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
			chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
			h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 追加聊天记录
 | 
			
		||||
		// for prompt
 | 
			
		||||
		promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
		historyUserMsg := model.ChatMessage{
 | 
			
		||||
			UserId:     userVo.Id,
 | 
			
		||||
			ChatId:     session.ChatId,
 | 
			
		||||
			RoleId:     role.Id,
 | 
			
		||||
			Type:       types.PromptMsg,
 | 
			
		||||
			Icon:       userVo.Avatar,
 | 
			
		||||
			Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
			Tokens:     promptToken,
 | 
			
		||||
			UseContext: true,
 | 
			
		||||
			Model:      req.Model,
 | 
			
		||||
		}
 | 
			
		||||
		historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
		historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
		res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// for reply
 | 
			
		||||
		// 计算本次对话消耗的总 token 数量
 | 
			
		||||
		replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
		totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
		historyReplyMsg := model.ChatMessage{
 | 
			
		||||
			UserId:     userVo.Id,
 | 
			
		||||
			ChatId:     session.ChatId,
 | 
			
		||||
			RoleId:     role.Id,
 | 
			
		||||
			Type:       types.ReplyMsg,
 | 
			
		||||
			Icon:       role.Icon,
 | 
			
		||||
			Content:    message.Content,
 | 
			
		||||
			Tokens:     totalTokens,
 | 
			
		||||
			UseContext: true,
 | 
			
		||||
			Model:      req.Model,
 | 
			
		||||
		}
 | 
			
		||||
		historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
		historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
		res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新用户算力
 | 
			
		||||
		h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
		// 保存当前会话
 | 
			
		||||
		var chatItem model.ChatItem
 | 
			
		||||
		res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			chatItem.ChatId = session.ChatId
 | 
			
		||||
			chatItem.UserId = session.UserId
 | 
			
		||||
			chatItem.RoleId = role.Id
 | 
			
		||||
			chatItem.ModelId = session.Model.Id
 | 
			
		||||
			if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
				chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
			} else {
 | 
			
		||||
				chatItem.Title = prompt
 | 
			
		||||
			}
 | 
			
		||||
			chatItem.Model = req.Model
 | 
			
		||||
			h.DB.Create(&chatItem)
 | 
			
		||||
		}
 | 
			
		||||
		h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -304,7 +304,7 @@ func main() {
 | 
			
		||||
			group.GET("config/get", h.Get)
 | 
			
		||||
			group.POST("active", h.Active)
 | 
			
		||||
			group.GET("config/get/license", h.GetLicense)
 | 
			
		||||
			group.GET("config/get/draw", h.GetDrawingConfig)
 | 
			
		||||
			group.GET("config/get/app", h.GetAppConfig)
 | 
			
		||||
			group.POST("config/update/draw", h.SaveDrawingConfig)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user