The 'stop generate' and 'regenerate response' function is ready

This commit is contained in:
RockYang 2023-04-11 18:58:27 +08:00
parent a2cf97b039
commit 1db20959e7
5 changed files with 237 additions and 47 deletions

View File

@ -3,6 +3,7 @@ package server
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -59,12 +60,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
delete(s.ChatClients, sessionId) delete(s.ChatClients, sessionId)
return return
} }
logger.Info("Receive a message: ", string(message)) logger.Info("Receive a message: ", string(message))
//replyMessage(client, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false) //replyMessage(client, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
//replyMessage(client, "![](images/wx.png)", true) //replyMessage(client, "![](images/wx.png)", true)
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录 ctx, cancel := context.WithCancel(context.Background())
err = s.sendMessage(session, chatRole, string(message), client, false) s.ReqCancelFunc[sessionId] = cancel
// 回复消息
err = s.sendMessage(ctx, session, chatRole, string(message), client, false)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
@ -73,7 +75,13 @@ func (s *Server) ChatHandle(c *gin.Context) {
} }
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error { func (s *Server) sendMessage(ctx context.Context, session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
cancel := s.ReqCancelFunc[session.SessionId]
defer func() {
cancel()
delete(s.ReqCancelFunc, session.SessionId)
}()
user, err := GetUser(session.Username) user, err := GetUser(session.Username)
if err != nil { if err != nil {
replyMessage(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false) replyMessage(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false)
@ -98,38 +106,38 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
replyMessage(ws, "![](images/start.png)", true) replyMessage(ws, "![](images/start.png)", true)
return nil return nil
} }
var r = types.ApiRequest{ var req = types.ApiRequest{
Model: s.Config.Chat.Model, Model: s.Config.Chat.Model,
Temperature: s.Config.Chat.Temperature, Temperature: s.Config.Chat.Temperature,
MaxTokens: s.Config.Chat.MaxTokens, MaxTokens: s.Config.Chat.MaxTokens,
Stream: true, Stream: true,
} }
var context []types.Message var chatCtx []types.Message
var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key) var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key)
if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext { if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext {
context = v.Messages chatCtx = v.Messages
} else { } else {
context = role.Context chatCtx = role.Context
} }
if s.DebugMode { if s.DebugMode {
logger.Infof("会话上下文:%+v", context) logger.Infof("会话上下文:%+v", chatCtx)
} }
req.Messages = append(chatCtx, types.Message{
Role: "user",
Content: prompt,
})
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
var client *http.Client var client *http.Client
var retryCount = 5 var retryCount = 5 // 重试次数
var response *http.Response var response *http.Response
var apiKey string var apiKey string
var failedKey = "" var failedKey = ""
var failedProxyURL = "" var failedProxyURL = ""
for retryCount > 0 { for retryCount > 0 {
r.Messages = append(context, types.Message{ requestBody, err := json.Marshal(req)
Role: "user",
Content: prompt,
})
requestBody, err := json.Marshal(r)
if err != nil { if err != nil {
return err return err
} }
@ -139,6 +147,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
return err return err
} }
request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json") request.Header.Add("Content-Type", "application/json")
proxyURL := s.getProxyURL(failedProxyURL) proxyURL := s.getProxyURL(failedProxyURL)
@ -164,7 +173,10 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
response, err = client.Do(request) response, err = client.Do(request)
if err == nil { if err == nil {
break break
} else if strings.Contains(err.Error(), "context canceled") {
return errors.New("用户取消了请求:" + prompt)
} else { } else {
logger.Error("HTTP API 请求失败:" + err.Error())
failedKey = apiKey failedKey = apiKey
failedProxyURL = proxyURL failedProxyURL = proxyURL
} }
@ -205,14 +217,14 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
_ = utils.SaveConfig(s.Config, s.ConfigPath) _ = utils.SaveConfig(s.Config, s.ConfigPath)
// 重发当前消息 // 重发当前消息
return s.sendMessage(session, role, prompt, ws, false) return s.sendMessage(ctx, session, role, prompt, ws, false)
// 上下文超出长度了 // 上下文超出长度了
} else if strings.Contains(line, "This model's maximum context length is 4097 tokens") { } else if strings.Contains(line, "This model's maximum context length is 4097 tokens") {
logger.Infof("会话上下文长度超出限制, Username: %s", user.Name) logger.Infof("会话上下文长度超出限制, Username: %s", user.Name)
// 重置上下文,重发当前消息 // 重置上下文,重发当前消息
delete(s.ChatContexts, ctxKey) delete(s.ChatContexts, ctxKey)
return s.sendMessage(session, role, prompt, ws, true) return s.sendMessage(ctx, session, role, prompt, ws, true)
} else if !strings.Contains(line, "data:") { } else if !strings.Contains(line, "data:") {
continue continue
} }
@ -246,7 +258,18 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
IsHelloMsg: false, IsHelloMsg: false,
}) })
} }
// 监控取消信号
select {
case <-ctx.Done():
// 结束输出
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
_ = response.Body.Close()
return errors.New("用户取消了请求:" + prompt)
default:
continue
} }
} // end for
_ = response.Body.Close() // 关闭资源 _ = response.Body.Close() // 关闭资源
// 消息发送成功 // 消息发送成功
@ -260,27 +283,26 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro
if message.Role == "" { if message.Role == "" {
message.Role = "assistant" message.Role = "assistant"
} }
// 追加上下文消息
useMsg := types.Message{Role: "user", Content: prompt}
context = append(context, useMsg)
message.Content = strings.Join(contents, "") message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息 // 更新上下文消息
if s.Config.Chat.EnableContext { if s.Config.Chat.EnableContext {
context = append(context, message) chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
s.ChatContexts[ctxKey] = types.ChatContext{ s.ChatContexts[ctxKey] = types.ChatContext{
Messages: context, Messages: chatCtx,
LastAccessTime: time.Now().Unix(), LastAccessTime: time.Now().Unix(),
} }
} }
// 追加历史消息 // 追加历史消息
if user.EnableHistory { if user.EnableHistory {
err = AppendChatHistory(user.Name, role.Key, useMsg) err = AppendChatHistory(user.Name, role.Key, useMsg) // 提问消息
if err != nil { if err != nil {
return err return err
} }
err = AppendChatHistory(user.Name, role.Key, message) err = AppendChatHistory(user.Name, role.Key, message) // 回复消息
} }
} }
@ -431,3 +453,12 @@ func (s *Server) ClearHistoryHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
} }
// StopGenerateHandle 停止生成
func (s *Server) StopGenerateHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
cancel := s.ReqCancelFunc[sessionId]
cancel()
delete(s.ReqCancelFunc, sessionId)
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"embed" "embed"
"encoding/json" "encoding/json"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
@ -42,6 +43,7 @@ type Server struct {
ChatSession map[string]types.ChatSession //map[sessionId]User ChatSession map[string]types.ChatSession //map[sessionId]User
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
ChatClients map[string]*WsClient // Websocket 连接集合 ChatClients map[string]*WsClient // Websocket 连接集合
ReqCancelFunc map[string]context.CancelFunc // HttpClient 请求取消 handle function
DebugMode bool // 是否开启调试模式 DebugMode bool // 是否开启调试模式
} }
@ -67,6 +69,7 @@ func NewServer(configPath string) (*Server, error) {
ChatContexts: make(map[string]types.ChatContext, 16), ChatContexts: make(map[string]types.ChatContext, 16),
ChatSession: make(map[string]types.ChatSession), ChatSession: make(map[string]types.ChatSession),
ChatClients: make(map[string]*WsClient), ChatClients: make(map[string]*WsClient),
ReqCancelFunc: make(map[string]context.CancelFunc),
ApiKeyAccessStat: make(map[string]int64), ApiKeyAccessStat: make(map[string]int64),
}, nil }, nil
} }
@ -83,17 +86,18 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.Use(AuthorizeMiddleware(s)) engine.Use(AuthorizeMiddleware(s))
engine.Use(Recover) engine.Use(Recover)
engine.POST("/test", s.TestHandle) engine.POST("test", s.TestHandle)
engine.GET("/api/session/get", s.GetSessionHandle) engine.GET("api/session/get", s.GetSessionHandle)
engine.POST("/api/login", s.LoginHandle) engine.POST("api/login", s.LoginHandle)
engine.POST("/api/logout", s.LogoutHandle) engine.POST("api/logout", s.LogoutHandle)
engine.Any("/api/chat", s.ChatHandle) engine.Any("api/chat", s.ChatHandle)
engine.POST("api/chat/stop", s.StopGenerateHandle)
engine.POST("api/chat/history", s.GetChatHistoryHandle) engine.POST("api/chat/history", s.GetChatHistoryHandle)
engine.POST("api/chat/history/clear", s.ClearHistoryHandle) engine.POST("api/chat/history/clear", s.ClearHistoryHandle)
engine.POST("/api/config/set", s.ConfigSetHandle) engine.POST("api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle) engine.GET("api/config/chat-roles/get", s.GetChatRoleListHandle)
engine.GET("/api/config/chat-roles/add", s.AddChatRoleHandle) engine.GET("api/config/chat-roles/add", s.AddChatRoleHandle)
engine.POST("api/config/user/add", s.AddUserHandle) engine.POST("api/config/user/add", s.AddUserHandle)
engine.POST("api/config/user/batch-add", s.BatchAddUserHandle) engine.POST("api/config/user/batch-add", s.BatchAddUserHandle)
engine.POST("api/config/user/set", s.SetUserHandle) engine.POST("api/config/user/set", s.SetUserHandle)

View File

@ -1,10 +1,73 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"io"
"net/http"
"time" "time"
) )
func main() { func main() {
ctx, cancel := context.WithCancel(context.Background())
http.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
cancel()
_, _ = fmt.Fprintf(w, "请求取消!")
})
go func() {
err := http.ListenAndServe(":9999", nil)
if err != nil {
return
}
}()
testHttpClient(ctx)
}
// Http client 取消操作
func testHttpClient(ctx context.Context) {
req, err := http.NewRequest("GET", "http://localhost:2345", nil)
if err != nil {
fmt.Println(err)
return
}
req = req.WithContext(ctx)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println(err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
for {
time.Sleep(time.Second)
fmt.Println(time.Now())
select {
case <-ctx.Done():
fmt.Println("取消退出")
return
default:
continue
}
}
if err != nil {
fmt.Println(err)
return
}
fmt.Println(string(body))
}
func testDate() {
fmt.Println(time.Unix(1683336167, 0).Format("2006-01-02 15:04:05")) fmt.Println(time.Unix(1683336167, 0).Format("2006-01-02 15:04:05"))
} }

View File

@ -1,2 +1,2 @@
VUE_APP_API_HOST=https://ai.r9it.com VUE_APP_API_HOST=https://www.chat-plus.net
VUE_APP_WS_HOST=wss://ai.r9it.com VUE_APP_WS_HOST=wss://www.chat-plus.net

View File

@ -81,9 +81,26 @@
:icon="chat.icon" :icon="chat.icon"
:content="chat.content"/> :content="chat.content"/>
</div> </div>
</div><!-- end chat box --> </div><!-- end chat box -->
<div class="re-generate">
<div class="btn-box">
<el-button type="info" v-if="showStopGenerate" @click="stopGenerate" plain>
<el-icon>
<VideoPause/>
</el-icon>
停止生成
</el-button>
<el-button type="info" v-if="showReGenerate" @click="reGenerate" plain>
<el-icon>
<RefreshRight/>
</el-icon>
重新生成
</el-button>
</div>
</div>
<el-row class="chat-tool-box"> <el-row class="chat-tool-box">
<el-tooltip <el-tooltip
class="box-item" class="box-item"
@ -159,7 +176,17 @@ import ChatPrompt from "@/components/plus/ChatPrompt.vue";
import ChatReply from "@/components/plus/ChatReply.vue"; import ChatReply from "@/components/plus/ChatReply.vue";
import {isMobile, randString} from "@/utils/libs"; import {isMobile, randString} from "@/utils/libs";
import {ElMessage, ElMessageBox} from 'element-plus' import {ElMessage, ElMessageBox} from 'element-plus'
import {Tools, Lock, Delete, Picture, Search, ArrowDown, Monitor} from '@element-plus/icons-vue' import {
Tools,
Lock,
Delete,
Picture,
Search,
ArrowDown,
Monitor,
VideoPause,
RefreshRight
} from '@element-plus/icons-vue'
import ConfigDialog from '@/components/ConfigDialog.vue' import ConfigDialog from '@/components/ConfigDialog.vue'
import {httpPost, httpGet} from "@/utils/http"; import {httpPost, httpGet} from "@/utils/http";
import {getSessionId, setLoginUser} from "@/utils/storage"; import {getSessionId, setLoginUser} from "@/utils/storage";
@ -168,7 +195,20 @@ import 'highlight.js/styles/a11y-dark.css'
export default defineComponent({ export default defineComponent({
name: "ChatPlus", name: "ChatPlus",
components: {ArrowDown, Search, ChatPrompt, ChatReply, Tools, Lock, Delete, Picture, Monitor, ConfigDialog}, components: {
RefreshRight,
VideoPause,
ArrowDown,
Search,
ChatPrompt,
ChatReply,
Tools,
Lock,
Delete,
Picture,
Monitor,
ConfigDialog
},
data() { data() {
return { return {
title: 'ChatGPT 控制台', title: 'ChatGPT 控制台',
@ -184,6 +224,11 @@ export default defineComponent({
replyIcon: 'images/avatar/gpt.png', // replyIcon: 'images/avatar/gpt.png', //
roleName: "", // roleName: "", //
showStopGenerate: false, //
showReGenerate: false, //
canReGenerate: false, //
previousText: '', //
lineBuffer: '', // lineBuffer: '', //
connectingMessageBox: null, // connectingMessageBox: null, //
errorMessage: null, // errorMessage: null, //
@ -260,8 +305,15 @@ export default defineComponent({
content: "", content: "",
cursor: true cursor: true
}); });
} else if (data.type === 'end') { if (data['is_hello_msg'] !== true) {
this.canReGenerate = true;
}
} else if (data.type === 'end') { //
this.sending = false; this.sending = false;
if (data['is_hello_mgs'] !== true) {
this.showReGenerate = true;
}
this.showStopGenerate = false;
this.lineBuffer = ''; // this.lineBuffer = ''; //
} else { } else {
this.lineBuffer += data.content; this.lineBuffer += data.content;
@ -372,6 +424,7 @@ export default defineComponent({
inputKeyDown: function (e) { inputKeyDown: function (e) {
if (e.keyCode === 13) { if (e.keyCode === 13) {
if (this.sending) { if (this.sending) {
ElMessage.warning("AI 正在作答中,请稍后...");
e.preventDefault(); e.preventDefault();
} else { } else {
this.sendMessage(); this.sendMessage();
@ -390,10 +443,8 @@ export default defineComponent({
target.blur(); target.blur();
} }
if (this.inputValue.trim().length === 0) { if (this.inputValue.trim().length === 0 || this.sending) {
return false; return false;
} else if (this.sending) {
ElMessage.warning("AI 正在作答中请稍后...")
} }
// //
@ -405,8 +456,11 @@ export default defineComponent({
}); });
this.sending = true; this.sending = true;
this.showStopGenerate = true;
this.showReGenerate = false;
this.socket.send(this.inputValue); this.socket.send(this.inputValue);
this.$refs["text-input"].blur(); this.$refs["text-input"].blur();
this.previousText = this.inputValue;
this.inputValue = ''; this.inputValue = '';
// textarea // textarea
setTimeout(() => this.$refs["text-input"].focus(), 100); setTimeout(() => this.$refs["text-input"].focus(), 100);
@ -494,6 +548,26 @@ export default defineComponent({
}).catch(() => { }).catch(() => {
ElMessage.error("注销失败"); ElMessage.error("注销失败");
}) })
},
//
stopGenerate: function () {
this.showStopGenerate = false;
httpPost("/api/chat/stop").then(() => {
console.log("stopped generate.")
this.sending = false;
if (this.canReGenerate) {
this.showReGenerate = true;
}
})
},
//
reGenerate: function () {
this.sending = true;
this.showStopGenerate = true;
this.showReGenerate = false;
this.socket.send('重新生成上述问题的答案:' + this.previousText);
} }
}, },
@ -676,6 +750,24 @@ export default defineComponent({
} }
} }
.re-generate {
position: relative;
display: flex;
justify-content: center;
.btn-box {
position absolute
bottom 10px;
.el-button {
.el-icon {
margin-right 5px;
}
}
}
}
.chat-tool-box { .chat-tool-box {
padding 10px; padding 10px;
border-top: 1px solid #2F3032 border-top: 1px solid #2F3032