The 'stop generate' and 'regenerate response' functions are ready

This commit is contained in:
RockYang
2023-04-11 18:58:27 +08:00
parent a2cf97b039
commit 1db20959e7
5 changed files with 237 additions and 47 deletions

View File

@@ -3,6 +3,7 @@ package server
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -59,12 +60,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
delete(s.ChatClients, sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
//replyMessage(client, "![](images/wx.png)", true)
// TODO: 当前只保持当前会话的上下文,不保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
err = s.sendMessage(session, chatRole, string(message), client, false)
ctx, cancel := context.WithCancel(context.Background())
s.ReqCancelFunc[sessionId] = cancel
// 回复消息
err = s.sendMessage(ctx, session, chatRole, string(message), client, false)
if err != nil {
logger.Error(err)
}
@@ -73,7 +75,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
func (s *Server) sendMessage(ctx context.Context, session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
cancel := s.ReqCancelFunc[session.SessionId]
defer func() {
cancel()
delete(s.ReqCancelFunc, session.SessionId)
}()
user, err := GetUser(session.Username)
if err != nil {
replyMessage(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
@@ -98,38 +106,38 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
replyMessage(ws, "![](images/start.png)", true)
return nil
}
var r = types.ApiRequest{
var req = types.ApiRequest{
Model: s.Config.Chat.Model,
Temperature: s.Config.Chat.Temperature,
MaxTokens: s.Config.Chat.MaxTokens,
Stream: true,
}
var context []types.Message
var chatCtx []types.Message
var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key)
if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext {
context = v.Messages
chatCtx = v.Messages
} else {
context = role.Context
chatCtx = role.Context
}
if s.DebugMode {
logger.Infof("会话上下文:%+v", context)
logger.Infof("会话上下文:%+v", chatCtx)
}
req.Messages = append(chatCtx, types.Message{
Role: "user",
Content: prompt,
})
// 创建 HttpClient 请求对象
var client *http.Client
var retryCount = 5
var retryCount = 5 // 重试次数
var response *http.Response
var apiKey string
var failedKey = ""
var failedProxyURL = ""
for retryCount > 0 {
r.Messages = append(context, types.Message{
Role: "user",
Content: prompt,
})
requestBody, err := json.Marshal(r)
requestBody, err := json.Marshal(req)
if err != nil {
return err
}
@@ -139,6 +147,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
return err
}
request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json")
proxyURL := s.getProxyURL(failedProxyURL)
@@ -164,7 +173,10 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
response, err = client.Do(request)
if err == nil {
break
} else if strings.Contains(err.Error(), "context canceled") {
return errors.New("用户取消了请求:" + prompt)
} else {
logger.Error("HTTP API 请求失败:" + err.Error())
failedKey = apiKey
failedProxyURL = proxyURL
}
@@ -205,14 +217,14 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
_ = utils.SaveConfig(s.Config, s.ConfigPath)
// 重发当前消息
return s.sendMessage(session, role, prompt, ws, false)
return s.sendMessage(ctx, session, role, prompt, ws, false)
// 上下文超出长度了
} else if strings.Contains(line, "This model's maximum context length is 4097 tokens") {
logger.Infof("会话上下文长度超出限制, Username: %s", user.Name)
// 重置上下文,重发当前消息
delete(s.ChatContexts, ctxKey)
return s.sendMessage(session, role, prompt, ws, true)
return s.sendMessage(ctx, session, role, prompt, ws, true)
} else if !strings.Contains(line, "data:") {
continue
}
@@ -246,7 +258,18 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
IsHelloMsg: false,
})
}
}
// 监控取消信号
select {
case <-ctx.Done():
// 结束输出
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
_ = response.Body.Close()
return errors.New("用户取消了请求:" + prompt)
default:
continue
}
} // end for
_ = response.Body.Close() // 关闭资源
// 消息发送成功
@@ -260,27 +283,26 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
if message.Role == "" {
message.Role = "assistant"
}
// 追加上下文消息
useMsg := types.Message{Role: "user", Content: prompt}
context = append(context, useMsg)
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息
if s.Config.Chat.EnableContext {
context = append(context, message)
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
s.ChatContexts[ctxKey] = types.ChatContext{
Messages: context,
Messages: chatCtx,
LastAccessTime: time.Now().Unix(),
}
}
// 追加历史消息
if user.EnableHistory {
err = AppendChatHistory(user.Name, role.Key, useMsg)
err = AppendChatHistory(user.Name, role.Key, useMsg) // 提问消息
if err != nil {
return err
}
err = AppendChatHistory(user.Name, role.Key, message)
err = AppendChatHistory(user.Name, role.Key, message) // 回复消息
}
}
@@ -431,3 +453,12 @@ func (s *Server) ClearHistoryHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
}
// StopGenerateHandle stops the response that is currently being generated for
// the calling session.
//
// The session ID is read from the request header named by types.TokenName and
// used to look up the cancel function stored in s.ReqCancelFunc. If a cancel
// function is present it is invoked; the map entry is removed in every case.
// A success response is always written, so asking to stop when nothing is
// being generated is a harmless no-op.
func (s *Server) StopGenerateHandle(c *gin.Context) {
	sessionId := c.GetHeader(types.TokenName)
	// A missing key yields a nil func, and calling a nil func panics. That
	// happens whenever the request has already finished (or never started),
	// so only invoke the cancel function when one is actually registered.
	if cancel := s.ReqCancelFunc[sessionId]; cancel != nil {
		cancel()
	}
	delete(s.ReqCancelFunc, sessionId)
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
}