fix: add lock map data structure, fixed bug for 'concurrent map writes'

This commit is contained in:
RockYang
2023-06-16 15:32:11 +08:00
parent c9875d24b4
commit 111572e3f2
10 changed files with 191 additions and 54 deletions

View File

@@ -48,12 +48,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
logger.Error(err)
return
}
sessionId := c.Query("sessionId")
roleId := param.GetInt(c, "roleId", 0)
chatId := c.Query("chatId")
sessionId := c.Query("session_id")
roleId := param.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
chatModel := c.Query("model")
session, ok := h.app.ChatSession[sessionId]
if !ok {
session := h.app.ChatSession.Get(sessionId)
if session.SessionId == "" {
logger.Info("用户未登录")
c.Abort()
return
@@ -70,7 +71,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session.ChatId = chatId
session.Model = chatModel
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := core.NewWsClient(ws)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
@@ -80,21 +81,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 保存会话连接
h.app.ChatClients[chatId] = client
h.app.ChatClients.Put(sessionId, client)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
delete(h.app.ChatClients, chatId)
delete(h.app.ReqCancelFunc, chatId)
h.app.ChatClients.Delete(sessionId)
h.app.ReqCancelFunc.Delete(sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background())
h.app.ReqCancelFunc[chatId] = cancel
h.app.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil {
@@ -109,7 +110,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws core.Client) error {
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
promptCreatedAt := time.Now() // 记录提问时间
var user model.User
@@ -152,8 +153,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 加载聊天上下文
var chatCtx []types.Message
if userVo.ChatConfig.EnableContext {
if v, ok := h.app.ChatContexts[session.ChatId]; ok {
chatCtx = v
if h.app.ChatContexts.Has(session.ChatId) {
chatCtx = h.app.ChatContexts.Get(session.ChatId)
} else {
// 加载角色信息
var messages []types.Message
@@ -262,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.app.ChatContexts[session.ChatId] = chatCtx
h.app.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
@@ -348,9 +349,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录
chatContext := h.app.ChatContexts[session.ChatId]
chatContext := h.app.ChatContexts.Get(session.ChatId)
chatContext = chatContext[len(chatContext)-3:]
h.app.ChatContexts[session.ChatId] = chatContext
h.app.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
@@ -410,20 +411,20 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
}
// 回复客户片段端消息
func replyChunkMessage(client core.Client, message types.WsMessage) {
func replyChunkMessage(client types.Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*core.WsClient).Send(msg)
err = client.(*types.WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws core.Client, message string) {
func replyMessage(ws types.Client, message string) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
@@ -444,10 +445,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
chatId := c.Query("chat_id")
if cancel, ok := h.app.ReqCancelFunc[chatId]; ok {
cancel()
delete(h.app.ReqCancelFunc, chatId)
sessionId := c.Query("session_id")
if h.app.ReqCancelFunc.Has(sessionId) {
h.app.ReqCancelFunc.Get(sessionId)()
h.app.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}

View File

@@ -89,7 +89,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
}
// 清空会话上下文
delete(h.app.ChatContexts, chatId)
h.app.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg)
}
@@ -144,7 +144,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
}
// 清空会话上下文
delete(h.app.ChatContexts, chat.ChatId)
h.app.ChatContexts.Delete(chat.ChatId)
}
// 删除所有的会话记录
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})

View File

@@ -168,7 +168,7 @@ func (h *UserHandler) Login(c *gin.Context) {
}
// 记录登录信息在服务器
h.app.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId}
h.app.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
// 加载用户订阅的聊天角色
var roleMap map[string]int
@@ -237,9 +237,10 @@ func (h *UserHandler) Logout(c *gin.Context) {
logger.Error("Error for save session: ", err)
}
// 删除 websocket 会话列表
delete(h.app.ChatSession, sessionId)
h.app.ChatSession.Delete(sessionId)
// 关闭 socket 连接
if client, ok := h.app.ChatClients[sessionId]; ok {
client := h.app.ChatClients.Get(sessionId)
if client != nil {
client.Close()
}
resp.SUCCESS(c)
@@ -248,7 +249,8 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
sessionId := c.GetHeader(types.TokenSessionName)
if session, ok := h.app.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
session := h.app.ChatSession.Get(sessionId)
if session.ClientIP == c.ClientIP() {
resp.SUCCESS(c, session)
} else {
resp.NotAuth(c)