mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
refactor websocket message protocol, keep the only connection for all clients
This commit is contained in:
@@ -19,7 +19,6 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -43,55 +42,9 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
// Generate 生成思维导图
|
||||
func (h *MarkMapHandler) Generate(c *gin.Context) {
|
||||
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.clients.Put(userId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := client.Receive()
|
||||
if err != nil {
|
||||
client.Close()
|
||||
h.clients.Delete(userId)
|
||||
return
|
||||
}
|
||||
|
||||
var message types.ReplyMessage
|
||||
err = utils.JsonDecode(string(msg), &message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 心跳消息
|
||||
if message.Type == "heartbeat" {
|
||||
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
|
||||
continue
|
||||
}
|
||||
// change model
|
||||
if message.Type == "model_id" {
|
||||
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("Receive a message: ", message.Content)
|
||||
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
utils.ReplyErrorMessage(client, err.Error())
|
||||
} else {
|
||||
utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
|
||||
@@ -170,13 +123,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
||||
break
|
||||
}
|
||||
|
||||
utils.SendChunkMessage(client, types.ReplyMessage{
|
||||
Type: types.WsMsgTypeContent,
|
||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
utils.SendMsg(client, types.ReplyMessage{
|
||||
Type: types.MsgTypeText,
|
||||
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||
})
|
||||
} // end for
|
||||
|
||||
utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
|
||||
|
||||
} else {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
|
||||
Reference in New Issue
Block a user