mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-11 03:33:48 +08:00
fix: add lock map data structure, fixed bug for 'concurrent map writes'
This commit is contained in:
@@ -48,12 +48,13 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("sessionId")
|
||||
roleId := param.GetInt(c, "roleId", 0)
|
||||
chatId := c.Query("chatId")
|
||||
sessionId := c.Query("session_id")
|
||||
roleId := param.GetInt(c, "role_id", 0)
|
||||
chatId := c.Query("chat_id")
|
||||
chatModel := c.Query("model")
|
||||
session, ok := h.app.ChatSession[sessionId]
|
||||
if !ok {
|
||||
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
if session.SessionId == "" {
|
||||
logger.Info("用户未登录")
|
||||
c.Abort()
|
||||
return
|
||||
@@ -70,7 +71,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
session.ChatId = chatId
|
||||
session.Model = chatModel
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
|
||||
client := core.NewWsClient(ws)
|
||||
client := types.NewWsClient(ws)
|
||||
var chatRole model.ChatRole
|
||||
res = h.db.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
@@ -80,21 +81,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存会话连接
|
||||
h.app.ChatClients[chatId] = client
|
||||
h.app.ChatClients.Put(sessionId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
client.Close()
|
||||
delete(h.app.ChatClients, chatId)
|
||||
delete(h.app.ReqCancelFunc, chatId)
|
||||
h.app.ChatClients.Delete(sessionId)
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
return
|
||||
}
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "这是一条测试消息!")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.app.ReqCancelFunc[chatId] = cancel
|
||||
h.app.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
@@ -109,7 +110,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws core.Client) error {
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
|
||||
promptCreatedAt := time.Now() // 记录提问时间
|
||||
|
||||
var user model.User
|
||||
@@ -152,8 +153,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// 加载聊天上下文
|
||||
var chatCtx []types.Message
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
if v, ok := h.app.ChatContexts[session.ChatId]; ok {
|
||||
chatCtx = v
|
||||
if h.app.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.app.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
// 加载角色信息
|
||||
var messages []types.Message
|
||||
@@ -262,7 +263,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.app.ChatContexts[session.ChatId] = chatCtx
|
||||
h.app.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
@@ -348,9 +349,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
|
||||
// 只保留最近的三条记录
|
||||
chatContext := h.app.ChatContexts[session.ChatId]
|
||||
chatContext := h.app.ChatContexts.Get(session.ChatId)
|
||||
chatContext = chatContext[len(chatContext)-3:]
|
||||
h.app.ChatContexts[session.ChatId] = chatContext
|
||||
h.app.ChatContexts.Put(session.ChatId, chatContext)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
@@ -410,20 +411,20 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
}
|
||||
|
||||
// 回复客户片段端消息
|
||||
func replyChunkMessage(client core.Client, message types.WsMessage) {
|
||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
return
|
||||
}
|
||||
err = client.(*core.WsClient).Send(msg)
|
||||
err = client.(*types.WsClient).Send(msg)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for reply message: %v", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 回复客户端一条完整的消息
|
||||
func replyMessage(ws core.Client, message string) {
|
||||
func replyMessage(ws types.Client, message string) {
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
@@ -444,10 +445,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
chatId := c.Query("chat_id")
|
||||
if cancel, ok := h.app.ReqCancelFunc[chatId]; ok {
|
||||
cancel()
|
||||
delete(h.app.ReqCancelFunc, chatId)
|
||||
sessionId := c.Query("session_id")
|
||||
if h.app.ReqCancelFunc.Has(sessionId) {
|
||||
h.app.ReqCancelFunc.Get(sessionId)()
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user