refactor websocket message protocol, keep the only connection for all clients

This commit is contained in:
RockYang
2024-09-27 17:50:54 +08:00
parent 478bc32ddd
commit 2debe7e927
29 changed files with 407 additions and 567 deletions

View File

@@ -1,4 +1,4 @@
package chatimpl
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -15,8 +15,6 @@ import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store/model"
@@ -33,14 +31,11 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
BaseHandler
redis *redis.Client
uploadManager *oss.UploaderManager
licenseService *service.LicenseService
@@ -51,7 +46,7 @@ type ChatHandler struct {
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
BaseHandler: BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
licenseService: licenseService,
@@ -61,106 +56,6 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
}
}
// ChatHandle 处理聊天 WebSocket 请求
func (h *ChatHandler) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
sessionId := c.Query("session_id")
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
c.Abort()
return
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
modelId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyErrorMessage(client, "当前AI模型暂未启用对话已关闭")
c.Abort()
return
}
session := &types.ChatSession{
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
}
session.ChatId = chatId
session.Model = types.ChatModel{
Id: chatModel.Id,
Name: chatModel.Name,
Value: chatModel.Value,
Power: chatModel.Power,
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId}
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
cancelFunc := h.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
h.ReqCancelFunc.Delete(sessionId)
}
return
}
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
logger.Infof("Receive a message:%+v", message)
session.Tools = message.Tools
session.Stream = message.Stream
ctx, cancel := context.WithCancel(context.Background())
h.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.SendMessage(client, err.Error())
} else {
utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
logger.Infof("回答完毕: %v", message.Content)
}
}
}()
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
if !h.App.Debug {
defer func() {
@@ -206,7 +101,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") {
utils.ReplyContent(ws, "AI 正在思考...\n")
utils.SendChunkMsg(ws, "AI 正在思考...\n")
req.Stream = false
session.Start = time.Now().Unix()
} else {

View File

@@ -1,4 +1,4 @@
package chatimpl
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
roleMap[role.Id] = role
}
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
if len(chats) == 0 {
resp.SUCCESS(c, items)
return
}
for _, chat := range chats {
var item vo.ChatItem
err := utils.CopyObject(chat, &item)
if err == nil {
item.Id = chat.Id
item.Icon = roleMap[chat.RoleId].Icon
items = append(items, item)
}
}
}
var roleIds = make([]uint, 0)
var modelValues = make([]string, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
modelValues = append(modelValues, chat.Model)
}
var roles []model.ChatRole
var models []model.ChatModel
roleMap := make(map[uint]model.ChatRole)
modelMap := make(map[string]model.ChatModel)
h.DB.Where("id IN ?", roleIds).Find(&roles)
h.DB.Where("value IN ?", modelValues).Find(&models)
for _, role := range roles {
roleMap[role.Id] = role
}
for _, m := range models {
modelMap[m.Value] = m
}
for _, chat := range chats {
var item vo.ChatItem
err := utils.CopyObject(chat, &item)
if err == nil {
item.Id = chat.Id
item.Icon = roleMap[chat.RoleId].Icon
item.ModelId = modelMap[chat.Model].Id
items = append(items, item)
}
}
resp.SUCCESS(c, items)
}

View File

@@ -20,9 +20,7 @@ import (
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
)
type DallJobHandler struct {
@@ -45,49 +43,6 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.dallService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.dallService.Clients.Delete(uint(userId))
return
}
var message types.ReplyMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {

View File

@@ -19,7 +19,6 @@ import (
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
@@ -43,55 +42,9 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
}
}
func (h *MarkMapHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
return
}
// Generate 生成思维导图
func (h *MarkMapHandler) Generate(c *gin.Context) {
modelId := h.GetInt(c, "model_id", 0)
userId := h.GetInt(c, "user_id", 0)
client := types.NewWsClient(ws)
h.clients.Put(userId, client)
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.clients.Delete(userId)
return
}
var message types.ReplyMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 MarkMap 心跳消息:", message.Content)
continue
}
// change model
if message.Type == "model_id" {
modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
continue
}
logger.Info("Receive a message: ", message.Content)
err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
if err != nil {
logger.Error(err)
utils.ReplyErrorMessage(client, err.Error())
} else {
utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
}
}
}()
}
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
@@ -170,13 +123,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
break
}
utils.SendChunkMessage(client, types.ReplyMessage{
Type: types.WsMsgTypeContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
utils.SendMsg(client, types.ReplyMessage{
Type: types.MsgTypeText,
Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
} else {
body, _ := io.ReadAll(response.Body)

View File

@@ -19,12 +19,10 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
@@ -65,27 +63,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *MidJourneyHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.mjService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
// Image 创建一个绘画任务
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {

View File

@@ -1,4 +1,4 @@
package chatimpl
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
@@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.SendMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
utils.SendChunkMsg(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
@@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg})
utils.SendChunkMsg(ws, callMsg)
contents = append(contents, callMsg)
}
continue
@@ -153,10 +153,7 @@ func (h *ChatHandler) sendOpenAiMessage(
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsMsgTypeContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content)
}
} // end for
@@ -174,7 +171,7 @@ func (h *ChatHandler) sendOpenAiMessage(
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).
SetSuccessResult(&apiRes).Post(function.Action)
@@ -185,19 +182,13 @@ func (h *ChatHandler) sendOpenAiMessage(
errMsg = r.Status
}
if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsMsgTypeContent,
Content: msg,
})
contents = append(contents, msg)
errMsg = "调用函数工具出错:" + apiRes.Message + errMsg
contents = append(contents, errMsg)
} else {
utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsMsgTypeContent,
Content: apiRes.Data,
})
contents = append(contents, utils.InterfaceToString(apiRes.Data))
errMsg = utils.InterfaceToString(apiRes.Data)
contents = append(contents, errMsg)
}
utils.SendChunkMsg(ws, errMsg)
}
// 消息发送成功
@@ -226,7 +217,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.SendMessage(ws, content)
utils.SendChunkMsg(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())

View File

@@ -19,11 +19,8 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
@@ -59,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SdJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.sdService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {

View File

@@ -19,9 +19,7 @@ import (
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
@@ -44,27 +42,6 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SunoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.sunoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SunoHandler) Create(c *gin.Context) {
var data struct {

View File

@@ -19,7 +19,7 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekP
}
func (h *TestHandler) SseTest(c *gin.Context) {
//c.Header("Content-Type", "text/event-stream")
//c.Header("Body-Type", "text/event-stream")
//c.Header("Cache-Control", "no-cache")
//c.Header("Connection", "keep-alive")
//

View File

@@ -19,9 +19,7 @@ import (
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
@@ -44,27 +42,6 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *VideoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.videoService.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *VideoHandler) LumaCreate(c *gin.Context) {
var data struct {

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"geekai/core"
"geekai/core/types"
"geekai/service"
@@ -15,6 +16,7 @@ import (
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
)
@@ -22,12 +24,14 @@ import (
type WebsocketHandler struct {
BaseHandler
wsService *service.WebsocketService
wsService *service.WebsocketService
chatHandler *ChatHandler
}
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler {
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
return &WebsocketHandler{
BaseHandler: BaseHandler{App: app},
BaseHandler: BaseHandler{App: app, DB: db},
chatHandler: chatHandler,
wsService: s,
}
}
@@ -40,9 +44,9 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
return
}
userId := h.GetInt(c, "user_id", 0)
clientId := c.Query("client")
client := types.NewWsClient(ws)
clientId := c.Query("client_id")
client := types.NewWsClient(ws, clientId)
userId := h.GetLoginUserId(c)
if userId == 0 {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
@@ -63,6 +67,8 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.wsService.Clients.Delete(clientId)
break
}
var message types.InputMessage
@@ -72,12 +78,66 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
}
logger.Infof("Receive a message:%+v", message)
if message.Type == types.WsMsgTypePing {
_ = client.Send([]byte(`{"type":"pong"}`))
if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong")
continue
}
switch message.Channel {
// 当前只处理聊天消息,其他消息全部丢弃
var chatMessage types.ChatMessage
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
if err != nil || message.Channel != types.ChChat {
logger.Warnf("invalid message body:%+v", message.Body)
continue
}
var chatRole model.ChatRole
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
if err != nil || !chatRole.Enable {
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
continue
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
chatMessage.RoleId = chatRole.ModelId
}
// get model info
var chatModel model.ChatModel
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
if err != nil || chatModel.Enabled == false {
utils.SendAndFlush(client, "当前AI模型暂未启用请更换模型后再发起对话")
continue
}
session := &types.ChatSession{
ClientIP: c.ClientIP(),
UserId: userId,
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
chatMessage.RoleId = int(chat.RoleId)
}
session.ChatId = chatMessage.ChatId
session.Tools = chatMessage.Tools
session.Stream = chatMessage.Stream
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
ctx, cancel := context.WithCancel(context.Background())
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
if err != nil {
logger.Error(err)
utils.SendAndFlush(client, err.Error())
} else {
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
logger.Infof("回答完毕: %v", message.Body)
}
}