fix: add lock map data structure, fixed bug for 'concurrent map writes'

This commit is contained in:
RockYang
2023-06-16 15:32:11 +08:00
parent c9875d24b4
commit 111572e3f2
10 changed files with 191 additions and 54 deletions

View File

@@ -48,12 +48,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
logger.Error(err)
return
}
sessionId := c.Query("sessionId")
roleId := param.GetInt(c, "roleId", 0)
chatId := c.Query("chatId")
sessionId := c.Query("session_id")
roleId := param.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
chatModel := c.Query("model")
session, ok := h.app.ChatSession[sessionId]
if !ok {
session := h.app.ChatSession.Get(sessionId)
if session.SessionId == "" {
logger.Info("用户未登录")
c.Abort()
return
@@ -70,7 +71,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = chatModel
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := core.NewWsClient(ws)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
@@ -80,21 +81,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 保存会话连接
h.app.ChatClients[chatId] = client
h.app.ChatClients.Put(sessionId, client)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
delete(h.app.ChatClients, chatId)
delete(h.app.ReqCancelFunc, chatId)
h.app.ChatClients.Delete(sessionId)
h.app.ReqCancelFunc.Delete(sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background())
h.app.ReqCancelFunc[chatId] = cancel
h.app.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil {
@@ -109,7 +110,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws core.Client) error {
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
promptCreatedAt := time.Now() // 记录提问时间
var user model.User
@@ -152,8 +153,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 加载聊天上下文
var chatCtx []types.Message
if userVo.ChatConfig.EnableContext {
if v, ok := h.app.ChatContexts[session.ChatId]; ok {
chatCtx = v
if h.app.ChatContexts.Has(session.ChatId) {
chatCtx = h.app.ChatContexts.Get(session.ChatId)
} else {
// 加载角色信息
var messages []types.Message
@@ -262,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.app.ChatContexts[session.ChatId] = chatCtx
h.app.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
@@ -348,9 +349,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录
chatContext := h.app.ChatContexts[session.ChatId]
chatContext := h.app.ChatContexts.Get(session.ChatId)
chatContext = chatContext[len(chatContext)-3:]
h.app.ChatContexts[session.ChatId] = chatContext
h.app.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
@@ -410,20 +411,20 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
}
// 回复客户片段端消息
func replyChunkMessage(client core.Client, message types.WsMessage) {
func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*core.WsClient).Send(msg)
err = client.(*types.WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws core.Client, message string) {
func replyMessage(ws types.Client, message string) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
@@ -444,10 +445,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
chatId := c.Query("chat_id")
if cancel, ok := h.app.ReqCancelFunc[chatId]; ok {
cancel()
delete(h.app.ReqCancelFunc, chatId)
sessionId := c.Query("session_id")
if h.app.ReqCancelFunc.Has(sessionId) {
h.app.ReqCancelFunc.Get(sessionId)()
h.app.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}