diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index 237db83f..04d57355 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -53,3 +53,10 @@ type ApiError struct {
 
 const PromptMsg = "prompt" // prompt message
 const ReplyMsg = "reply" // reply message
+
+var ModelToTokens = map[string]int{
+	"gpt-3.5-turbo": 4096,
+	"gpt-3.5-turbo-16k": 16384,
+	"gpt-4": 8192,
+	"gpt-4-32k": 32768,
+}
diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index 84f7d185..b908872e 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -180,21 +180,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 	if h.App.ChatContexts.Has(session.ChatId) {
 		chatCtx = h.App.ChatContexts.Get(session.ChatId)
 	} else {
-		// 加载角色信息
+		// calculate the tokens of current request, to prevent to exceeding the max tokens num
+		tokens := req.MaxTokens
+		for _, f := range types.InnerFunctions {
+			tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
+			tokens += tks
+		}
+
+		// loading the role context
 		var messages []types.Message
 		err := utils.JsonDecode(role.Context, &messages)
 		if err == nil {
 			for _, v := range messages {
+				tks, _ := utils.CalcTokens(v.Content, req.Model)
+				if tokens+tks >= types.ModelToTokens[req.Model] {
+					break
+				}
+				tokens += tks
 				chatCtx = append(chatCtx, v)
 			}
 		}
 
-		// 加载最近的聊天记录作为聊天上下文
+		// loading recent chat history as chat context
 		if chatConfig.ContextDeep > 0 {
 			var historyMessages []model.HistoryMessage
-			res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
+			res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("created_at desc").Find(&historyMessages)
 			if res.Error == nil {
 				for _, msg := range historyMessages {
+					if tokens+msg.Tokens >= types.ModelToTokens[session.Model] {
+						break
+					}
+					tokens += msg.Tokens
 					ms := types.Message{Role: "user", Content: msg.Content}
 					if msg.Type == types.ReplyMsg {
 						ms.Role = "assistant"
@@ -204,7 +220,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 				}
 			}
 		}
-		logger.Debugf("聊天上下文:%+v", chatCtx)
 	}
 
 	reqMgs := make([]interface{}, 0)
@@ -459,13 +474,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
 			replyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
 		} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
-			replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
-			// 只保留最近的三条记录
-			chatContext := h.App.ChatContexts.Get(session.ChatId)
-			if len(chatContext) > 3 {
-				chatContext = chatContext[len(chatContext)-3:]
-			}
-			h.App.ChatContexts.Put(session.ChatId, chatContext)
+			logger.Error(res.Error.Message)
+			replyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
+			h.App.ChatContexts.Delete(session.ChatId)
 			return h.sendMessage(ctx, session, role, prompt, ws)
 		} else {
 			replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index 48559d49..12326e6d 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -274,7 +274,7 @@ const title = ref('ChatGPT-智能助手');
 const logo = 'images/logo.png';
 const rewardImg = ref('images/reward.png')
 const models = ref([])
-const model = ref('gpt-3.5-turbo')
+const model = ref('gpt-3.5-turbo-16k')
 const chatData = ref([]);
 const allChats = ref([]); // 会话列表
 const chatList = ref(allChats.value);