mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
146 lines
4.1 KiB
Go
146 lines
4.1 KiB
Go
package handler
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"context"
|
||
"geekai/core"
|
||
"geekai/core/types"
|
||
"geekai/service"
|
||
"geekai/store/model"
|
||
"geekai/utils"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/gorilla/websocket"
|
||
"gorm.io/gorm"
|
||
"net/http"
|
||
)
|
||
|
||
// Websocket 连接处理 handler
|
||
|
||
type WebsocketHandler struct {
|
||
BaseHandler
|
||
wsService *service.WebsocketService
|
||
chatHandler *ChatHandler
|
||
}
|
||
|
||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
|
||
return &WebsocketHandler{
|
||
BaseHandler: BaseHandler{App: app, DB: db},
|
||
chatHandler: chatHandler,
|
||
wsService: s,
|
||
}
|
||
}
|
||
|
||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||
if err != nil {
|
||
logger.Error(err)
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
clientId := c.Query("client_id")
|
||
client := types.NewWsClient(ws, clientId)
|
||
userId := h.GetLoginUserId(c)
|
||
if userId == 0 {
|
||
_ = client.Send([]byte("Invalid user_id"))
|
||
c.Abort()
|
||
return
|
||
}
|
||
var user model.User
|
||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||
_ = client.Send([]byte("Invalid user_id"))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
h.wsService.Clients.Put(clientId, client)
|
||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||
go func() {
|
||
for {
|
||
_, msg, err := client.Receive()
|
||
if err != nil {
|
||
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||
client.Close()
|
||
h.wsService.Clients.Delete(clientId)
|
||
break
|
||
}
|
||
|
||
var message types.InputMessage
|
||
err = utils.JsonDecode(string(msg), &message)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
logger.Debugf("Receive a message:%+v", message)
|
||
if message.Type == types.MsgTypePing {
|
||
utils.SendChannelMsg(client, types.ChPing, "pong")
|
||
continue
|
||
}
|
||
|
||
// 当前只处理聊天消息,其他消息全部丢弃
|
||
var chatMessage types.ChatMessage
|
||
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
|
||
if err != nil || message.Channel != types.ChChat {
|
||
logger.Warnf("invalid message body:%+v", message.Body)
|
||
continue
|
||
}
|
||
var chatRole model.ChatRole
|
||
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
|
||
if err != nil || !chatRole.Enable {
|
||
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
|
||
continue
|
||
}
|
||
// if the role bind a model_id, use role's bind model_id
|
||
if chatRole.ModelId > 0 {
|
||
chatMessage.RoleId = chatRole.ModelId
|
||
}
|
||
// get model info
|
||
var chatModel model.ChatModel
|
||
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
|
||
if err != nil || chatModel.Enabled == false {
|
||
utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
|
||
continue
|
||
}
|
||
|
||
session := &types.ChatSession{
|
||
ClientIP: c.ClientIP(),
|
||
UserId: userId,
|
||
}
|
||
|
||
// use old chat data override the chat model and role ID
|
||
var chat model.ChatItem
|
||
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
|
||
if chat.Id > 0 {
|
||
chatModel.Id = chat.ModelId
|
||
chatMessage.RoleId = int(chat.RoleId)
|
||
}
|
||
|
||
session.ChatId = chatMessage.ChatId
|
||
session.Tools = chatMessage.Tools
|
||
session.Stream = chatMessage.Stream
|
||
// 复制模型数据
|
||
err = utils.CopyObject(chatModel, &session.Model)
|
||
if err != nil {
|
||
logger.Error(err, chatModel)
|
||
}
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
|
||
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
|
||
if err != nil {
|
||
logger.Error(err)
|
||
utils.SendAndFlush(client, err.Error())
|
||
} else {
|
||
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
|
||
logger.Infof("回答完毕: %v", message.Body)
|
||
}
|
||
|
||
}
|
||
}()
|
||
}
|