mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 02:03:42 +08:00
fix: add lock map data structure, fixed bug for 'concurrent map writes'
This commit is contained in:
@@ -48,12 +48,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("sessionId")
|
||||
roleId := param.GetInt(c, "roleId", 0)
|
||||
chatId := c.Query("chatId")
|
||||
sessionId := c.Query("session_id")
|
||||
roleId := param.GetInt(c, "role_id", 0)
|
||||
chatId := c.Query("chat_id")
|
||||
chatModel := c.Query("model")
|
||||
session, ok := h.app.ChatSession[sessionId]
|
||||
if !ok {
|
||||
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
if session.SessionId == "" {
|
||||
logger.Info("用户未登录")
|
||||
c.Abort()
|
||||
return
|
||||
@@ -70,7 +71,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
session.ChatId = chatId
|
||||
session.Model = chatModel
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
|
||||
client := core.NewWsClient(ws)
|
||||
client := types.NewWsClient(ws)
|
||||
var chatRole model.ChatRole
|
||||
res = h.db.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
@@ -80,21 +81,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存会话连接
|
||||
h.app.ChatClients[chatId] = client
|
||||
h.app.ChatClients.Put(sessionId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
client.Close()
|
||||
delete(h.app.ChatClients, chatId)
|
||||
delete(h.app.ReqCancelFunc, chatId)
|
||||
h.app.ChatClients.Delete(sessionId)
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
return
|
||||
}
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "这是一条测试消息!")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.app.ReqCancelFunc[chatId] = cancel
|
||||
h.app.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
@@ -109,7 +110,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws core.Client) error {
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
|
||||
var user model.User
|
||||
@@ -152,8 +153,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// 加载聊天上下文
|
||||
var chatCtx []types.Message
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
if v, ok := h.app.ChatContexts[session.ChatId]; ok {
|
||||
chatCtx = v
|
||||
if h.app.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.app.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
// 加载角色信息
|
||||
var messages []types.Message
|
||||
@@ -262,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.app.ChatContexts[session.ChatId] = chatCtx
|
||||
h.app.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
@@ -348,9 +349,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
|
||||
// 只保留最近的三条记录
|
||||
chatContext := h.app.ChatContexts[session.ChatId]
|
||||
chatContext := h.app.ChatContexts.Get(session.ChatId)
|
||||
chatContext = chatContext[len(chatContext)-3:]
|
||||
h.app.ChatContexts[session.ChatId] = chatContext
|
||||
h.app.ChatContexts.Put(session.ChatId, chatContext)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
@@ -410,20 +411,20 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
}
|
||||
|
||||
// 回复客户片段端消息
|
||||
func replyChunkMessage(client core.Client, message types.WsMessage) {
|
||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
return
|
||||
}
|
||||
err = client.(*core.WsClient).Send(msg)
|
||||
err = client.(*types.WsClient).Send(msg)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for reply message: %v", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 回复客户端一条完整的消息
|
||||
func replyMessage(ws core.Client, message string) {
|
||||
func replyMessage(ws types.Client, message string) {
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
@@ -444,10 +445,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
chatId := c.Query("chat_id")
|
||||
if cancel, ok := h.app.ReqCancelFunc[chatId]; ok {
|
||||
cancel()
|
||||
delete(h.app.ReqCancelFunc, chatId)
|
||||
sessionId := c.Query("session_id")
|
||||
if h.app.ReqCancelFunc.Has(sessionId) {
|
||||
h.app.ReqCancelFunc.Get(sessionId)()
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 清空会话上下文
|
||||
delete(h.app.ChatContexts, chatId)
|
||||
h.app.ChatContexts.Delete(chatId)
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
|
||||
}
|
||||
// 清空会话上下文
|
||||
delete(h.app.ChatContexts, chat.ChatId)
|
||||
h.app.ChatContexts.Delete(chat.ChatId)
|
||||
}
|
||||
// 删除所有的会话记录
|
||||
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 记录登录信息在服务器
|
||||
h.app.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId}
|
||||
h.app.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
|
||||
|
||||
// 加载用户订阅的聊天角色
|
||||
var roleMap map[string]int
|
||||
@@ -237,9 +237,10 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
logger.Error("Error for save session: ", err)
|
||||
}
|
||||
// 删除 websocket 会话列表
|
||||
delete(h.app.ChatSession, sessionId)
|
||||
h.app.ChatSession.Delete(sessionId)
|
||||
// 关闭 socket 连接
|
||||
if client, ok := h.app.ChatClients[sessionId]; ok {
|
||||
client := h.app.ChatClients.Get(sessionId)
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
@@ -248,7 +249,8 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
// Session 获取/验证会话
|
||||
func (h *UserHandler) Session(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.TokenSessionName)
|
||||
if session, ok := h.app.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
if session.ClientIP == c.ClientIP() {
|
||||
resp.SUCCESS(c, session)
|
||||
} else {
|
||||
resp.NotAuth(c)
|
||||
|
||||
Reference in New Issue
Block a user